AstrAI/khaosz/config/model_config.py

43 lines
1.1 KiB
Python

import json
from dataclasses import asdict, dataclass, fields
from typing import Optional, Self
@dataclass
class ModelConfig:
# basic config
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None
rope_theta: Optional[float] = None
# GQA
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self:
config = {}
with open(config_path, "r") as f:
config.update(json.load(f))
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)