refactor(config): 重命名 TransformerConfig 为 ModelConfig
This commit is contained in:
parent
66a551217e
commit
7e5ecf3b7d
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from khaosz.model.transformer import TransformerConfig, Transformer
|
from khaosz.model.transformer import ModelConfig, Transformer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -15,7 +15,7 @@ class BenchmarkResult:
|
||||||
class GenerationBenchmark:
|
class GenerationBenchmark:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: TransformerConfig,
|
config: ModelConfig,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.float16
|
dtype: torch.dtype = torch.float16
|
||||||
):
|
):
|
||||||
|
|
@ -173,7 +173,7 @@ def print_benchmark_result(result: BenchmarkResult):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = TransformerConfig(
|
config = ModelConfig(
|
||||||
vocab_size=10000,
|
vocab_size=10000,
|
||||||
n_dim=1536,
|
n_dim=1536,
|
||||||
n_head=24,
|
n_head=24,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ __author__ = "ViperEkura"
|
||||||
|
|
||||||
from khaosz.api import Khaosz
|
from khaosz.api import Khaosz
|
||||||
from khaosz.config import (
|
from khaosz.config import (
|
||||||
TransformerConfig,
|
ModelConfig,
|
||||||
ParameterLoader,
|
ParameterLoader,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
|
|
@ -41,7 +41,7 @@ __all__ = [
|
||||||
"SemanticTextSplitter",
|
"SemanticTextSplitter",
|
||||||
"PriorityTextSplitter",
|
"PriorityTextSplitter",
|
||||||
|
|
||||||
"TransformerConfig",
|
"ModelConfig",
|
||||||
"ParameterLoader",
|
"ParameterLoader",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
||||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
@ -9,7 +9,7 @@ __all__ = [
|
||||||
"ModelParameter",
|
"ModelParameter",
|
||||||
"Checkpoint",
|
"Checkpoint",
|
||||||
"ParameterLoader",
|
"ParameterLoader",
|
||||||
"TransformerConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
"ScheduleConfig",
|
"ScheduleConfig",
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Optional, Self
|
from typing import Any, Dict, Optional, Self
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TransformerConfig:
|
class ModelConfig:
|
||||||
# basic config
|
# basic config
|
||||||
vocab_size: Optional[int] = None
|
vocab_size: Optional[int] = None
|
||||||
n_dim: Optional[int] = None
|
n_dim: Optional[int] = None
|
||||||
|
|
@ -21,7 +21,7 @@ class TransformerConfig:
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
def load(self, config_path: str) -> Self:
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
config: dict = json.load(f)
|
config: Dict[str, Any] = json.load(f)
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Self, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,11 +20,11 @@ class BaseModelIO:
|
||||||
self,
|
self,
|
||||||
model: Optional[nn.Module] = None,
|
model: Optional[nn.Module] = None,
|
||||||
tokenizer: Optional[BpeTokenizer] = None,
|
tokenizer: Optional[BpeTokenizer] = None,
|
||||||
config: Optional[TransformerConfig] = None
|
config: Optional[ModelConfig] = None
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer or BpeTokenizer()
|
self.tokenizer = tokenizer or BpeTokenizer()
|
||||||
self.config = config or TransformerConfig()
|
self.config = config or ModelConfig()
|
||||||
|
|
||||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||||
"""Get standardized file paths for model components."""
|
"""Get standardized file paths for model components."""
|
||||||
|
|
@ -79,8 +79,8 @@ class ModelParameter(BaseModelIO):
|
||||||
default_factory=BpeTokenizer,
|
default_factory=BpeTokenizer,
|
||||||
metadata={"help": "Tokenizer for the model."}
|
metadata={"help": "Tokenizer for the model."}
|
||||||
)
|
)
|
||||||
config: TransformerConfig = field(
|
config: ModelConfig = field(
|
||||||
default_factory=TransformerConfig,
|
default_factory=ModelConfig,
|
||||||
metadata={"help": "Transformer model configuration."}
|
metadata={"help": "Transformer model configuration."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -103,8 +103,8 @@ class Checkpoint(BaseModelIO):
|
||||||
default_factory=BpeTokenizer,
|
default_factory=BpeTokenizer,
|
||||||
metadata={"help": "Tokenizer for the model."}
|
metadata={"help": "Tokenizer for the model."}
|
||||||
)
|
)
|
||||||
config: TransformerConfig = field(
|
config: ModelConfig = field(
|
||||||
default_factory=TransformerConfig,
|
default_factory=ModelConfig,
|
||||||
metadata={"help": "Transformer model configuration."}
|
metadata={"help": "Transformer model configuration."}
|
||||||
)
|
)
|
||||||
optimizer_state: Dict[str, Any] = field(
|
optimizer_state: Dict[str, Any] = field(
|
||||||
|
|
@ -230,7 +230,7 @@ class ParameterLoader:
|
||||||
def create_checkpoint(
|
def create_checkpoint(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: BpeTokenizer,
|
tokenizer: BpeTokenizer,
|
||||||
config: TransformerConfig,
|
config: ModelConfig,
|
||||||
loss_list: Optional[list[float]] = None,
|
loss_list: Optional[list[float]] = None,
|
||||||
optimizer: Optional[optim.Optimizer] = None,
|
optimizer: Optional[optim.Optimizer] = None,
|
||||||
) -> Checkpoint:
|
) -> Checkpoint:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||||
from khaosz.config import ModelParameter, TransformerConfig
|
from khaosz.config import ModelParameter, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
|
def apply_sampling_strategies(
|
||||||
|
|
@ -187,7 +187,7 @@ class EmbeddingEncoderCore:
|
||||||
class KVCacheManager:
|
class KVCacheManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: TransformerConfig,
|
config: ModelConfig,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
device: torch.device = "cuda",
|
device: torch.device = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16
|
dtype: torch.dtype = torch.bfloat16
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Mapping, Optional, Tuple
|
from typing import Any, Mapping, Optional, Tuple
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
|
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,7 +61,7 @@ def process_attention_mask(
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
class Transformer(nn.Module):
|
||||||
def __init__(self, config: TransformerConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
|
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import pytest
|
||||||
import matplotlib
|
import matplotlib
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
@ -102,7 +102,7 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
transformer_config = TransformerConfig().load(config_path)
|
transformer_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(transformer_config).to(device=device)
|
model = Transformer(transformer_config).to(device=device)
|
||||||
tokenizer = BpeTokenizer()
|
tokenizer = BpeTokenizer()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
||||||
tokenizer.save(tokenizer_path)
|
tokenizer.save(tokenizer_path)
|
||||||
|
|
||||||
transformer_config = TransformerConfig().load(config_path)
|
transformer_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(transformer_config)
|
model = Transformer(transformer_config)
|
||||||
st.save_file(model.state_dict(), model_path)
|
st.save_file(model.state_dict(), model_path)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue