43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
import json
from dataclasses import asdict, dataclass, fields
from typing import Optional, Self
|
|
|
|
|
|
@dataclass
|
|
class ModelConfig:
|
|
# basic config
|
|
vocab_size: Optional[int] = None
|
|
dim: Optional[int] = None
|
|
|
|
n_layers: Optional[int] = None
|
|
norm_eps: Optional[float] = None
|
|
dim_ffn: Optional[int] = None
|
|
tie_weight: Optional[bool] = None
|
|
|
|
# RoPE
|
|
max_len: Optional[int] = None
|
|
rope_theta: Optional[float] = None
|
|
|
|
# GQA
|
|
n_heads: Optional[int] = None
|
|
n_kv_heads: Optional[int] = None
|
|
use_qk_norm: Optional[bool] = None
|
|
use_gated_attention: Optional[bool] = None
|
|
|
|
def load(self, config_path: str) -> Self:
|
|
config = {}
|
|
with open(config_path, "r") as f:
|
|
config.update(json.load(f))
|
|
|
|
for key, value in config.items():
|
|
if hasattr(self, key):
|
|
setattr(self, key, value)
|
|
|
|
return self
|
|
|
|
def save(self, config_path: str):
|
|
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
|
with open(config_path, "w") as f:
|
|
json.dump(config_dict, f, indent=4)
|