diff --git a/khaosz/config/param_config.py b/khaosz/config/param_config.py index 6e6cbbd..136d8bd 100644 --- a/khaosz/config/param_config.py +++ b/khaosz/config/param_config.py @@ -12,19 +12,22 @@ from khaosz.data.tokenizer import BpeTokenizer from khaosz.config.model_config import ModelConfig from khaosz.model.transformer import Transformer - +@dataclass class BaseModelIO: """Base class for model I/O operations.""" - def __init__( - self, - model: Optional[nn.Module] = None, - tokenizer: Optional[BpeTokenizer] = None, - config: Optional[ModelConfig] = None - ): - self.model = model - self.tokenizer = tokenizer or BpeTokenizer() - self.config = config or ModelConfig() + model: Optional[nn.Module] = field( + default=None, + metadata={"help": "Transformer model."} + ) + tokenizer: BpeTokenizer = field( + default_factory=BpeTokenizer, + metadata={"help": "Tokenizer for the model."} + ) + config: ModelConfig = field( + default_factory=ModelConfig, + metadata={"help": "Transformer model configuration."} + ) def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]: """Get standardized file paths for model components.""" @@ -71,19 +74,6 @@ class BaseModelIO: class ModelParameter(BaseModelIO): """Container for model parameters with serialization capabilities.""" - model: Optional[nn.Module] = field( - default=None, - metadata={"help": "Transformer model."} - ) - tokenizer: BpeTokenizer = field( - default_factory=BpeTokenizer, - metadata={"help": "Tokenizer for the model."} - ) - config: ModelConfig = field( - default_factory=ModelConfig, - metadata={"help": "Transformer model configuration."} - ) - def save(self, save_dir: Union[str, Path]): self.save_components(save_dir) @@ -95,18 +85,6 @@ class ModelParameter(BaseModelIO): class Checkpoint(BaseModelIO): """Extended model parameters with training state.""" - model: Optional[nn.Module] = field( - default=None, - metadata={"help": "Transformer model."} - ) - tokenizer: BpeTokenizer = field( - default_factory=BpeTokenizer, - metadata={"help": "Tokenizer for the model."} - ) - config: ModelConfig = field( - default_factory=ModelConfig, - metadata={"help": "Transformer model configuration."} - ) optimizer_state: Dict[str, Any] = field( default=None, metadata={"help": "Optimizer state."} @@ -129,50 +107,46 @@ class Checkpoint(BaseModelIO): ) def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]: - paths = self._get_file_paths(directory) - paths.update({ - "loss_list": paths["model"].parent / "loss.pkl", - "loss_plot": paths["model"].parent / "loss.png", - "optimizer_state": paths["model"].parent / "optimizer_state.pkl", - "sampler_state": paths["model"].parent / "sampler_state.pkl" - }) - return paths + dir_path = Path(directory) + return { + "loss_plot": dir_path / "loss_plot.png", + "training_state": dir_path / "training_state.pkl" + } + def to_dict(self) -> Dict[str, Any]: + return { + "optimizer_state": self.optimizer_state, + "scheduler_state": self.scheduler_state, + "epoch": self.epoch, + "batch_iter": self.batch_iter, + "loss_list": self.loss_list, + } + + def from_dict(self, data: Dict[str, Any]) -> Self: + self.optimizer_state = data["optimizer_state"] + self.scheduler_state = data["scheduler_state"] + self.epoch = data["epoch"] + self.batch_iter = data["batch_iter"] + self.loss_list = data["loss_list"] + def save_training_state(self, save_dir: Union[str, Path]): paths = self._get_training_paths(save_dir) # Save loss plot self._plot_loss(str(paths["loss_plot"])) - # Save loss list - with open(str(paths["loss_list"]), "wb") as f: - pkl.dump(self.loss_list, f) - - # Save optimizer state - with open(str(paths["optimizer_state"]), "wb") as f: - pkl.dump(self.optimizer_state, f) - - # Save sampler state - with open(str(paths["sampler_state"]), "wb") as f: - pkl.dump(self.scheduler_state, f) + # Save training state + with open(str(paths["training_state"]), "wb") as f: + pkl.dump(self.to_dict(), f) def load_training_state(self, load_dir: Union[str, Path]) -> Self: paths = self._get_training_paths(load_dir) - # Load loss list - if paths["loss_list"].exists(): - with open(str(paths["loss_list"]), "rb") as f: - self.loss_list = pkl.load(f) + # Load training state + with open(str(paths["training_state"]), "rb") as f: + train_state = pkl.load(f) - # Load optimizer state - if paths["optimizer_state"].exists(): - with open(str(paths["optimizer_state"]), "rb") as f: - self.optimizer_state = pkl.load(f) - - # Load sampler state - if paths["sampler_state"].exists(): - with open(str(paths["sampler_state"]), "rb") as f: - self.scheduler_state = pkl.load(f) + self.from_dict(train_state) return self @@ -189,7 +163,7 @@ class Checkpoint(BaseModelIO): plt.xlabel("Batch") plt.ylabel("Loss") plt.grid(True) - plt.savefig(save_path, dpi=300, bbox_inches="tight") + plt.savefig(save_path, dpi=30, bbox_inches="tight") plt.close() def save(self, save_dir: Union[str, Path]):