AstrAI/astrai/config/param_config.py

115 lines
3.5 KiB
Python

import torch.nn as nn
import safetensors.torch as st
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional, Self, Union
from pathlib import Path
from astrai.data.tokenizer import BpeTokenizer
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
@contextmanager
def disable_random_init(enable: bool = True):
init_functions = [
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
"kaiming_uniform_",
"zeros_",
"ones_",
"constant_",
"normal_",
"uniform_",
]
original_funcs = {}
for name in init_functions:
if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
if enable:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
@dataclass
class BaseModelIO:
"""Base class for model I/O operations."""
model: nn.Module = field(
default_factory=nn.Identity, metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
)
config: ModelConfig = field(
default_factory=ModelConfig,
metadata={"help": "Transformer model configuration."},
)
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get standardized file paths for model components."""
dir_path = Path(directory)
return {
"model": dir_path / "model.safetensors",
"config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json",
}
def save_components(self, save_dir: Union[str, Path]):
"""Save core model components."""
paths = self._get_file_paths(save_dir)
paths["model"].parent.mkdir(parents=True, exist_ok=True)
if self.model is not None:
st.save_file(self.model.state_dict(), str(paths["model"]))
self.config.save(str(paths["config"]))
self.tokenizer.save(str(paths["tokenizer"]))
def load_components(
self, load_dir: Union[str, Path], disable_init: bool = False
) -> Self:
"""Load core model components."""
paths = self._get_file_paths(load_dir)
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if isinstance(self.model, nn.Identity):
with disable_random_init(enable=disable_init):
self.model = Transformer(self.config)
if paths["model"].exists():
state_dict = st.load_file(str(paths["model"]))
self.model.load_state_dict(state_dict)
return self
def to(self, *args, **kwargs) -> "BaseModelIO":
"""Move model to device."""
if self.model is not None:
self.model.to(*args, **kwargs)
return self
@dataclass
class ModelParameter(BaseModelIO):
"""Container for model parameters with serialization capabilities."""
@classmethod
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
instance.save_components(save_dir)
@classmethod
def load(
cls, load_dir: Union[str, Path], disable_init: bool = False
) -> "ModelParameter":
instance = cls()
return instance.load_components(load_dir, disable_init=disable_init)