refactor(config): 重命名 TransformerConfig 为 ModelConfig
This commit is contained in:
parent
66a551217e
commit
7e5ecf3b7d
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from typing import Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from khaosz.model.transformer import TransformerConfig, Transformer
|
||||
from khaosz.model.transformer import ModelConfig, Transformer
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -15,7 +15,7 @@ class BenchmarkResult:
|
|||
class GenerationBenchmark:
|
||||
def __init__(
|
||||
self,
|
||||
config: TransformerConfig,
|
||||
config: ModelConfig,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16
|
||||
):
|
||||
|
|
@ -173,7 +173,7 @@ def print_benchmark_result(result: BenchmarkResult):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = TransformerConfig(
|
||||
config = ModelConfig(
|
||||
vocab_size=10000,
|
||||
n_dim=1536,
|
||||
n_head=24,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ __author__ = "ViperEkura"
|
|||
|
||||
from khaosz.api import Khaosz
|
||||
from khaosz.config import (
|
||||
TransformerConfig,
|
||||
ModelConfig,
|
||||
ParameterLoader,
|
||||
TrainConfig,
|
||||
)
|
||||
|
|
@ -41,7 +41,7 @@ __all__ = [
|
|||
"SemanticTextSplitter",
|
||||
"PriorityTextSplitter",
|
||||
|
||||
"TransformerConfig",
|
||||
"ModelConfig",
|
||||
"ParameterLoader",
|
||||
"TrainConfig",
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from khaosz.config.model_config import TransformerConfig
|
||||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
||||
from khaosz.config.train_config import TrainConfig
|
||||
|
|
@ -9,7 +9,7 @@ __all__ = [
|
|||
"ModelParameter",
|
||||
"Checkpoint",
|
||||
"ParameterLoader",
|
||||
"TransformerConfig",
|
||||
"ModelConfig",
|
||||
"TrainConfig",
|
||||
|
||||
"ScheduleConfig",
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import json
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Self
|
||||
from typing import Any, Dict, Optional, Self
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig:
|
||||
class ModelConfig:
|
||||
# basic config
|
||||
vocab_size: Optional[int] = None
|
||||
n_dim: Optional[int] = None
|
||||
|
|
@ -21,7 +21,7 @@ class TransformerConfig:
|
|||
|
||||
def load(self, config_path: str) -> Self:
|
||||
with open(config_path, 'r') as f:
|
||||
config: dict = json.load(f)
|
||||
config: Dict[str, Any] = json.load(f)
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Self, Union
|
|||
from pathlib import Path
|
||||
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
from khaosz.config.model_config import TransformerConfig
|
||||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.model.transformer import Transformer
|
||||
|
||||
|
||||
|
|
@ -20,11 +20,11 @@ class BaseModelIO:
|
|||
self,
|
||||
model: Optional[nn.Module] = None,
|
||||
tokenizer: Optional[BpeTokenizer] = None,
|
||||
config: Optional[TransformerConfig] = None
|
||||
config: Optional[ModelConfig] = None
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer or BpeTokenizer()
|
||||
self.config = config or TransformerConfig()
|
||||
self.config = config or ModelConfig()
|
||||
|
||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||
"""Get standardized file paths for model components."""
|
||||
|
|
@ -79,8 +79,8 @@ class ModelParameter(BaseModelIO):
|
|||
default_factory=BpeTokenizer,
|
||||
metadata={"help": "Tokenizer for the model."}
|
||||
)
|
||||
config: TransformerConfig = field(
|
||||
default_factory=TransformerConfig,
|
||||
config: ModelConfig = field(
|
||||
default_factory=ModelConfig,
|
||||
metadata={"help": "Transformer model configuration."}
|
||||
)
|
||||
|
||||
|
|
@ -103,8 +103,8 @@ class Checkpoint(BaseModelIO):
|
|||
default_factory=BpeTokenizer,
|
||||
metadata={"help": "Tokenizer for the model."}
|
||||
)
|
||||
config: TransformerConfig = field(
|
||||
default_factory=TransformerConfig,
|
||||
config: ModelConfig = field(
|
||||
default_factory=ModelConfig,
|
||||
metadata={"help": "Transformer model configuration."}
|
||||
)
|
||||
optimizer_state: Dict[str, Any] = field(
|
||||
|
|
@ -230,7 +230,7 @@ class ParameterLoader:
|
|||
def create_checkpoint(
|
||||
model: nn.Module,
|
||||
tokenizer: BpeTokenizer,
|
||||
config: TransformerConfig,
|
||||
config: ModelConfig,
|
||||
loss_list: Optional[list[float]] = None,
|
||||
optimizer: Optional[optim.Optimizer] = None,
|
||||
) -> Checkpoint:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||
from khaosz.config import ModelParameter, TransformerConfig
|
||||
from khaosz.config import ModelParameter, ModelConfig
|
||||
|
||||
|
||||
def apply_sampling_strategies(
|
||||
|
|
@ -187,7 +187,7 @@ class EmbeddingEncoderCore:
|
|||
class KVCacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
config: TransformerConfig,
|
||||
config: ModelConfig,
|
||||
batch_size: int,
|
||||
device: torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
|
||||
from torch import Tensor
|
||||
from typing import Any, Mapping, Optional, Tuple
|
||||
from khaosz.config.model_config import TransformerConfig
|
||||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
|
||||
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ def process_attention_mask(
|
|||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config: TransformerConfig):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
import matplotlib
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from khaosz.config.model_config import TransformerConfig
|
||||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
from khaosz.model.transformer import Transformer
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ def base_test_env(request: pytest.FixtureRequest):
|
|||
with open(config_path, 'w') as f:
|
||||
json.dump(config, f)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
transformer_config = TransformerConfig().load(config_path)
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config).to(device=device)
|
||||
tokenizer = BpeTokenizer()
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def test_env(request: pytest.FixtureRequest):
|
|||
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
||||
tokenizer.save(tokenizer_path)
|
||||
|
||||
transformer_config = TransformerConfig().load(config_path)
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config)
|
||||
st.save_file(model.state_dict(), model_path)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue