refactor(config): rename TransformerConfig to ModelConfig

Author: ViperEkura
Date:   2025-11-07 07:31:12 +08:00
Parent: 66a551217e
Commit: 7e5ecf3b7d
9 changed files with 25 additions and 25 deletions
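For downstream code the change is a pure rename at the import site; the package re-exports updated below mean both "from khaosz import ..." and "from khaosz.config import ..." now expose ModelConfig. A minimal before/after sketch (constructor arguments copied from the benchmark script in this diff):

# before this commit
from khaosz import TransformerConfig
config = TransformerConfig(vocab_size=10000, n_dim=1536, n_head=24)

# after this commit
from khaosz import ModelConfig
config = ModelConfig(vocab_size=10000, n_dim=1536, n_head=24)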

View File

@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Any
 from dataclasses import dataclass
-from khaosz.model.transformer import TransformerConfig, Transformer
+from khaosz.model.transformer import ModelConfig, Transformer
 @dataclass
@@ -15,7 +15,7 @@ class BenchmarkResult:
 class GenerationBenchmark:
     def __init__(
         self,
-        config: TransformerConfig,
+        config: ModelConfig,
         device: str = "cuda",
         dtype: torch.dtype = torch.float16
     ):
@@ -173,7 +173,7 @@ def print_benchmark_result(result: BenchmarkResult):
 if __name__ == "__main__":
-    config = TransformerConfig(
+    config = ModelConfig(
         vocab_size=10000,
         n_dim=1536,
         n_head=24,

View File

@@ -3,7 +3,7 @@ __author__ = "ViperEkura"
 from khaosz.api import Khaosz
 from khaosz.config import (
-    TransformerConfig,
+    ModelConfig,
     ParameterLoader,
     TrainConfig,
 )
@@ -41,7 +41,7 @@ __all__ = [
     "SemanticTextSplitter",
     "PriorityTextSplitter",
-    "TransformerConfig",
+    "ModelConfig",
     "ParameterLoader",
     "TrainConfig",

View File

@@ -1,4 +1,4 @@
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
 from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
 from khaosz.config.train_config import TrainConfig
@@ -9,7 +9,7 @@ __all__ = [
     "ModelParameter",
     "Checkpoint",
     "ParameterLoader",
-    "TransformerConfig",
+    "ModelConfig",
     "TrainConfig",
     "ScheduleConfig",

View File

@@ -1,10 +1,10 @@
 import json
 from dataclasses import asdict, dataclass
-from typing import Optional, Self
+from typing import Any, Dict, Optional, Self
 @dataclass
-class TransformerConfig:
+class ModelConfig:
     # basic config
     vocab_size: Optional[int] = None
     n_dim: Optional[int] = None
@@ -21,7 +21,7 @@ class TransformerConfig:
     def load(self, config_path: str) -> Self:
         with open(config_path, 'r') as f:
-            config: dict = json.load(f)
+            config: Dict[str, Any] = json.load(f)
         for key, value in config.items():
             if hasattr(self, key):
                 setattr(self, key, value)
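Since load() sets only attributes that already exist on the dataclass (the hasattr check above), unknown JSON keys are silently ignored and the method returns the populated instance. A minimal usage sketch; the field names are taken from elsewhere in this diff and the full field list is not shown here:

import json
from khaosz.config import ModelConfig

cfg_dict = {"vocab_size": 10000, "n_dim": 1536, "n_head": 24}
with open("model_config.json", "w") as f:   # hypothetical file name
    json.dump(cfg_dict, f)

config = ModelConfig().load("model_config.json")  # load() returns self, populated from the JSON
assert config.n_dim == 1536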

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path
 from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.model.transformer import Transformer
@@ -20,11 +20,11 @@ class BaseModelIO:
         self,
         model: Optional[nn.Module] = None,
         tokenizer: Optional[BpeTokenizer] = None,
-        config: Optional[TransformerConfig] = None
+        config: Optional[ModelConfig] = None
     ):
         self.model = model
         self.tokenizer = tokenizer or BpeTokenizer()
-        self.config = config or TransformerConfig()
+        self.config = config or ModelConfig()
     def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
         """Get standardized file paths for model components."""
@@ -79,8 +79,8 @@ class ModelParameter(BaseModelIO):
         default_factory=BpeTokenizer,
         metadata={"help": "Tokenizer for the model."}
     )
-    config: TransformerConfig = field(
-        default_factory=TransformerConfig,
+    config: ModelConfig = field(
+        default_factory=ModelConfig,
         metadata={"help": "Transformer model configuration."}
     )
@@ -103,8 +103,8 @@ class Checkpoint(BaseModelIO):
         default_factory=BpeTokenizer,
         metadata={"help": "Tokenizer for the model."}
     )
-    config: TransformerConfig = field(
-        default_factory=TransformerConfig,
+    config: ModelConfig = field(
+        default_factory=ModelConfig,
         metadata={"help": "Transformer model configuration."}
     )
     optimizer_state: Dict[str, Any] = field(
@@ -230,7 +230,7 @@ class ParameterLoader:
     def create_checkpoint(
         model: nn.Module,
         tokenizer: BpeTokenizer,
-        config: TransformerConfig,
+        config: ModelConfig,
         loss_list: Optional[list[float]] = None,
         optimizer: Optional[optim.Optimizer] = None,
     ) -> Checkpoint:
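The default factories above changed in lockstep, so constructing the IO helpers without an explicit config still works. A minimal sketch of the fallback shown in the BaseModelIO hunk (import paths taken from this diff):

from khaosz.config.model_config import ModelConfig
from khaosz.config.param_config import BaseModelIO

io = BaseModelIO()                         # no model, default BpeTokenizer, no config given
assert isinstance(io.config, ModelConfig)  # __init__ falls back to a fresh ModelConfig()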

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from typing import Any, Callable, List, Tuple, Union, Optional, Self
-from khaosz.config import ModelParameter, TransformerConfig
+from khaosz.config import ModelParameter, ModelConfig
 def apply_sampling_strategies(
@@ -187,7 +187,7 @@ class EmbeddingEncoderCore:
 class KVCacheManager:
     def __init__(
         self,
-        config: TransformerConfig,
+        config: ModelConfig,
         batch_size: int,
         device: torch.device = "cuda",
         dtype: torch.dtype = torch.bfloat16

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
 from torch import Tensor
 from typing import Any, Mapping, Optional, Tuple
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
@@ -61,7 +61,7 @@ def process_attention_mask(
 class Transformer(nn.Module):
-    def __init__(self, config: TransformerConfig):
+    def __init__(self, config: ModelConfig):
         super().__init__()
         self.config = config
         self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
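One constraint visible in this hunk carries over unchanged: the rotary embedding is sized as config.n_dim // config.n_head, so n_dim must stay divisible by n_head. A minimal sketch; the field names come from this diff, the m_len value is illustrative, and the model likely needs further fields not shown here:

from khaosz.config import ModelConfig

config = ModelConfig(vocab_size=10000, n_dim=1536, n_head=24, m_len=1024)  # m_len value is illustrative
assert config.n_dim % config.n_head == 0   # per-head dimension passed to RotaryEmbedding: 64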

View File

@@ -9,7 +9,7 @@ import pytest
 import matplotlib
 from torch.utils.data import Dataset
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.data.tokenizer import BpeTokenizer
 from khaosz.model.transformer import Transformer
@@ -102,7 +102,7 @@ def base_test_env(request: pytest.FixtureRequest):
     with open(config_path, 'w') as f:
         json.dump(config, f)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    transformer_config = TransformerConfig().load(config_path)
+    transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config).to(device=device)
     tokenizer = BpeTokenizer()

View File

@@ -38,7 +38,7 @@ def test_env(request: pytest.FixtureRequest):
     tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
     tokenizer.save(tokenizer_path)
-    transformer_config = TransformerConfig().load(config_path)
+    transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config)
     st.save_file(model.state_dict(), model_path)