refactor(config): rename TransformerConfig to ModelConfig

ViperEkura 2025-11-07 07:31:12 +08:00
parent 66a551217e
commit 7e5ecf3b7d
9 changed files with 25 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Any
 from dataclasses import dataclass
-from khaosz.model.transformer import TransformerConfig, Transformer
+from khaosz.model.transformer import ModelConfig, Transformer
 @dataclass
@@ -15,7 +15,7 @@ class BenchmarkResult:
 class GenerationBenchmark:
     def __init__(
         self,
-        config: TransformerConfig,
+        config: ModelConfig,
         device: str = "cuda",
         dtype: torch.dtype = torch.float16
     ):
@@ -173,7 +173,7 @@ def print_benchmark_result(result: BenchmarkResult):
 if __name__ == "__main__":
-    config = TransformerConfig(
+    config = ModelConfig(
         vocab_size=10000,
         n_dim=1536,
         n_head=24,

View File

@@ -3,7 +3,7 @@ __author__ = "ViperEkura"
 from khaosz.api import Khaosz
 from khaosz.config import (
-    TransformerConfig,
+    ModelConfig,
     ParameterLoader,
     TrainConfig,
 )
@@ -41,7 +41,7 @@ __all__ = [
     "SemanticTextSplitter",
     "PriorityTextSplitter",
-    "TransformerConfig",
+    "ModelConfig",
     "ParameterLoader",
     "TrainConfig",

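Exporting only the new name means any downstream `from khaosz import TransformerConfig` now raises ImportError. A minimal backward-compatibility shim, hypothetical and not part of this commit, could keep the old name importable from khaosz/__init__.py via PEP 562:

import warnings

from khaosz.config import ModelConfig

def __getattr__(name: str):
    # Module-level __getattr__ (PEP 562): resolve the removed name lazily
    # and point callers at the rename.
    if name == "TransformerConfig":
        warnings.warn(
            "TransformerConfig was renamed to ModelConfig",
            DeprecationWarning,
            stacklevel=2,
        )
        return ModelConfig
    raise AttributeError(f"module 'khaosz' has no attribute {name!r}")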
View File

@@ -1,4 +1,4 @@
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
 from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
 from khaosz.config.train_config import TrainConfig
@@ -9,7 +9,7 @@ __all__ = [
     "ModelParameter",
     "Checkpoint",
     "ParameterLoader",
-    "TransformerConfig",
+    "ModelConfig",
     "TrainConfig",
     "ScheduleConfig",

View File

@@ -1,10 +1,10 @@
 import json
 from dataclasses import asdict, dataclass
-from typing import Optional, Self
+from typing import Any, Dict, Optional, Self
 @dataclass
-class TransformerConfig:
+class ModelConfig:
     # basic config
     vocab_size: Optional[int] = None
     n_dim: Optional[int] = None
@@ -21,7 +21,7 @@ class TransformerConfig:
     def load(self, config_path: str) -> Self:
         with open(config_path, 'r') as f:
-            config: dict = json.load(f)
+            config: Dict[str, Any] = json.load(f)
         for key, value in config.items():
             if hasattr(self, key):
                 setattr(self, key, value)

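Since load() mutates the instance in place and returns Self, construction and loading chain into one expression. A minimal round trip, using field names that appear elsewhere in this diff (the file name is illustrative):

import json

from khaosz.config import ModelConfig

# Keys matching dataclass fields are applied; anything else is
# silently skipped by the hasattr() check in load().
with open("config.json", "w") as f:
    json.dump({"vocab_size": 10000, "n_dim": 1536, "n_head": 24}, f)

config = ModelConfig().load("config.json")
assert config.vocab_size == 10000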
View File

@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path
 from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.model.transformer import Transformer
@@ -20,11 +20,11 @@ class BaseModelIO:
         self,
         model: Optional[nn.Module] = None,
         tokenizer: Optional[BpeTokenizer] = None,
-        config: Optional[TransformerConfig] = None
+        config: Optional[ModelConfig] = None
     ):
         self.model = model
         self.tokenizer = tokenizer or BpeTokenizer()
-        self.config = config or TransformerConfig()
+        self.config = config or ModelConfig()
     def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
         """Get standardized file paths for model components."""
@@ -79,8 +79,8 @@ class ModelParameter(BaseModelIO):
         default_factory=BpeTokenizer,
         metadata={"help": "Tokenizer for the model."}
     )
-    config: TransformerConfig = field(
-        default_factory=TransformerConfig,
+    config: ModelConfig = field(
+        default_factory=ModelConfig,
         metadata={"help": "Transformer model configuration."}
     )
@@ -103,8 +103,8 @@ class Checkpoint(BaseModelIO):
         default_factory=BpeTokenizer,
         metadata={"help": "Tokenizer for the model."}
     )
-    config: TransformerConfig = field(
-        default_factory=TransformerConfig,
+    config: ModelConfig = field(
+        default_factory=ModelConfig,
         metadata={"help": "Transformer model configuration."}
     )
     optimizer_state: Dict[str, Any] = field(
@@ -230,7 +230,7 @@ class ParameterLoader:
     def create_checkpoint(
         model: nn.Module,
         tokenizer: BpeTokenizer,
-        config: TransformerConfig,
+        config: ModelConfig,
         loss_list: Optional[list[float]] = None,
         optimizer: Optional[optim.Optimizer] = None,
     ) -> Checkpoint:

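create_checkpoint bundles model, tokenizer, config, and optimizer state into a Checkpoint. A usage sketch, assuming it is a @staticmethod on ParameterLoader (the diff shows only the signature and the enclosing class):

import torch.optim as optim

from khaosz.config import ModelConfig, ParameterLoader
from khaosz.data.tokenizer import BpeTokenizer
from khaosz.model.transformer import Transformer

config = ModelConfig().load("config.json")
model = Transformer(config)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Hypothetical call site; loss_list and optimizer are optional per the signature.
checkpoint = ParameterLoader.create_checkpoint(
    model=model,
    tokenizer=BpeTokenizer(),
    config=config,
    loss_list=[2.31, 2.05],
    optimizer=optimizer,
)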
View File

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from typing import Any, Callable, List, Tuple, Union, Optional, Self
-from khaosz.config import ModelParameter, TransformerConfig
+from khaosz.config import ModelParameter, ModelConfig
 def apply_sampling_strategies(
@@ -187,7 +187,7 @@ class EmbeddingEncoderCore:
 class KVCacheManager:
     def __init__(
         self,
-        config: TransformerConfig,
+        config: ModelConfig,
         batch_size: int,
         device: torch.device = "cuda",
         dtype: torch.dtype = torch.bfloat16

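A construction sketch for the renamed KVCacheManager signature above; the module path is not shown in this diff, so the import is hypothetical:

import torch

from khaosz.config import ModelConfig
from khaosz.model.generate_util import KVCacheManager  # hypothetical path

config = ModelConfig().load("config.json")

# Per-request cache sized from the model geometry carried by ModelConfig.
kv_cache = KVCacheManager(
    config=config,
    batch_size=4,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)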
View File

@@ -3,7 +3,7 @@ import torch.nn as nn
 from torch import Tensor
 from typing import Any, Mapping, Optional, Tuple
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
@@ -61,7 +61,7 @@ def process_attention_mask(
 class Transformer(nn.Module):
-    def __init__(self, config: TransformerConfig):
+    def __init__(self, config: ModelConfig):
         super().__init__()
         self.config = config
         self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)

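The constructor derives the per-head rotary dimension as config.n_dim // config.n_head, so n_dim must divide evenly by n_head. A minimal construction reusing values from the benchmark hunk (m_len is assumed; the diff truncates the field list):

from khaosz.config import ModelConfig
from khaosz.model.transformer import Transformer

# 1536 // 24 = 64 is the head dimension handed to RotaryEmbedding.
config = ModelConfig(vocab_size=10000, n_dim=1536, n_head=24, m_len=2048)
model = Transformer(config)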
View File

@@ -9,7 +9,7 @@ import pytest
 import matplotlib
 from torch.utils.data import Dataset
-from khaosz.config.model_config import TransformerConfig
+from khaosz.config.model_config import ModelConfig
 from khaosz.data.tokenizer import BpeTokenizer
 from khaosz.model.transformer import Transformer
@@ -102,7 +102,7 @@ def base_test_env(request: pytest.FixtureRequest):
     with open(config_path, 'w') as f:
         json.dump(config, f)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    transformer_config = TransformerConfig().load(config_path)
+    transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config).to(device=device)
     tokenizer = BpeTokenizer()

View File

@@ -38,7 +38,7 @@ def test_env(request: pytest.FixtureRequest):
     tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
     tokenizer.save(tokenizer_path)
-    transformer_config = TransformerConfig().load(config_path)
+    transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config)
     st.save_file(model.state_dict(), model_path)
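
Loading the weights back is symmetric, assuming st is safetensors.torch as the save_file call suggests:

import safetensors.torch as st

# load_file returns a plain {name: tensor} mapping, matching
# what save_file was given above.
model.load_state_dict(st.load_file(model_path))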