diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index f7aea3f..b4df66c 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -1,22 +1,21 @@ from khaosz.trainer.data_util import DatasetLoader from khaosz.trainer.trainer import Trainer +from khaosz.trainer.train_config import TrainConfig from khaosz.trainer.strategy import ( - TrainConfig, CosineScheduleConfig, SgdrScheduleConfig, StrategyFactory, SchedulerFactory ) -from khaosz.trainer.trainer_callback import ( - TrainerCallback, +from khaosz.trainer.train_callback import ( + TrainCallback, ProgressBarCallback, CheckpointCallback, - TrainerCallback, + TrainCallback, SchedulerCallback ) __all__ = [ - # strategy "DatasetLoader", "Trainer", "TrainConfig", @@ -26,9 +25,9 @@ __all__ = [ "SchedulerFactory", # callback - "TrainerCallback", + "TrainCallback", "ProgressBarCallback", "CheckpointCallback", - "TrainerCallback", + "TrainCallback", "SchedulerCallback", ] \ No newline at end of file diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 203bf43..9dc4dd4 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import Tensor from torch.optim import Optimizer from torch.utils.data import Dataset -from typing import Any, Literal, Tuple, Callable, Dict +from typing import Any, Literal, Optional, Tuple, Callable, Dict from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field @@ -176,56 +176,7 @@ class StrategyFactory: } strategy = train_strategy[train_type]() return strategy - -@dataclass -class TrainConfig: - - strategy: BaseStrategy = field( - default=None, - metadata={"help": "Training strategy."} - ) - dataset: Dataset = field( - default=None, - metadata={"help": "Dataset for training."} - ) - optimizer: Optimizer = field( - default=None, - metadata={"help": "Optimizer for training."} - ) - checkpoint_dir: str = field( - default="./checkpoint", - metadata={"help": "Checkpoint directory."} - ) - n_epoch: int = field( - default=1, - metadata={"help": "Number of epochs for training."} - ) - batch_size: int = field( - default=4, - metadata={"help": "Batch size for training."} - ) - checkpoint_interval: int = field( - default=5000, - metadata={"help": "Number of iterations between checkpoints."} - ) - accumulation_steps: int = field( - default=1, - metadata={"help": "Number of iterations between steps."} - ) - max_grad_norm: float = field( - default=1.0, - metadata={"help": "Maximum gradient norm."} - ) - random_seed: int = field( - default=3407, - metadata={"help": "Random seed."} - ) - - def get_kwargs(self)-> Dict[str, Any]: - config_dict = asdict(self) - return {k: v for k, v in config_dict.items() if v is not None} - @dataclass class ScheduleConfig(ABC): diff --git a/khaosz/trainer/trainer_callback.py b/khaosz/trainer/train_callback.py similarity index 96% rename from khaosz/trainer/trainer_callback.py rename to khaosz/trainer/train_callback.py index de898fc..bf8d40f 100644 --- a/khaosz/trainer/trainer_callback.py +++ b/khaosz/trainer/train_callback.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from khaosz.trainer.trainer import Trainer -class TrainerCallback: +class TrainCallback: """ Callback interface for trainer. and we use '_' to ignore unused parameters. @@ -52,7 +52,7 @@ class TrainerCallback: _ = trainer, kwargs -class ProgressBarCallback(TrainerCallback): +class ProgressBarCallback(TrainCallback): """ Progress bar callback for trainer. """ @@ -84,7 +84,7 @@ class ProgressBarCallback(TrainerCallback): self.progress_bar.close() -class CheckpointCallback(TrainerCallback): +class CheckpointCallback(TrainCallback): """ Checkpoint callback for trainer. """ @@ -122,7 +122,7 @@ class CheckpointCallback(TrainerCallback): self.last_ckpt_iter = current_iter -class GradientClippingCallback(TrainerCallback): +class GradientClippingCallback(TrainCallback): """ Gradient clipping callback for trainer. """ @@ -134,7 +134,7 @@ class GradientClippingCallback(TrainerCallback): ) -class SchedulerCallback(TrainerCallback): +class SchedulerCallback(TrainCallback): """ Scheduler callback for trainer. """ diff --git a/khaosz/trainer/train_config.py b/khaosz/trainer/train_config.py new file mode 100644 index 0000000..e93e358 --- /dev/null +++ b/khaosz/trainer/train_config.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass, field +from typing import Optional +from torch.utils.data import Dataset +from torch.optim import Optimizer +from khaosz.trainer.strategy import BaseStrategy + + +@dataclass +class TrainConfig: + + strategy: BaseStrategy = field( + default=None, + metadata={"help": "Training strategy."} + ) + dataset: Dataset = field( + default=None, + metadata={"help": "Dataset for training."} + ) + optimizer: Optimizer = field( + default=None, + metadata={"help": "Optimizer for training."} + ) + checkpoint_dir: str = field( + default="./checkpoint", + metadata={"help": "Checkpoint directory."} + ) + n_epoch: int = field( + default=1, + metadata={"help": "Number of epochs for training."} + ) + batch_size: int = field( + default=4, + metadata={"help": "Batch size for training."} + ) + checkpoint_interval: int = field( + default=5000, + metadata={"help": "Number of iterations between checkpoints."} + ) + accumulation_steps: int = field( + default=1, + metadata={"help": "Number of iterations between steps."} + ) + max_grad_norm: float = field( + default=1.0, + metadata={"help": "Maximum gradient norm."} + ) + random_seed: int = field( + default=3407, + metadata={"help": "Random seed."} + ) + num_workers: int = field( + default=0, + metadata={"help": "Number of workers for dataloader."} + ) + prefetch_factor: Optional[int] = field( + default=None, + metadata={"help": "Prefetch factor for dataloader."} + ) + pin_memory: bool = field( + default=False, + metadata={"help": "Pin memory for dataloader."} + ) \ No newline at end of file diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index f2826b4..1324ccb 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field, fields from typing import Optional, Self, TYPE_CHECKING from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -11,13 +11,17 @@ if TYPE_CHECKING: @dataclass class TrainContext: - dataloader: DataLoader - optimizer: Optimizer - sampler: RandomSampler - epoch: int - current_iter: int - loss: float - checkpoint: Checkpoint + dataloader: DataLoader = field(default=None) + optimizer: Optimizer = field(default=None) + sampler: RandomSampler = field(default=None) + epoch: int = field(default=0) + current_iter: int = field(default=0) + loss: float = field(default=0.0) + checkpoint: Checkpoint = field(default=None) + + def asdict(self) -> dict: + return {field.name: getattr(self, field.name) + for field in fields(self)} class TrainContextBuilder: @@ -82,7 +86,10 @@ class TrainContextBuilder: dataloader = DataLoader( self.trainer.train_config.dataset, batch_size=self.trainer.train_config.batch_size, - sampler=self._context.sampler + sampler=self._context.sampler, + num_workers=self.trainer.train_config.num_workers, + pin_memory=self.trainer.train_config.pin_memory, + prefetch_factor=self.trainer.train_config.prefetch_factor ) self._context.dataloader = dataloader return self diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index da9a88c..0228459 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -2,9 +2,10 @@ import logging from typing import Optional, List from khaosz.core import ModelParameter, Checkpoint -from khaosz.trainer.strategy import TrainConfig, ScheduleConfig -from khaosz.trainer.trainer_callback import ( - TrainerCallback, +from khaosz.trainer.strategy import ScheduleConfig +from khaosz.trainer.train_config import TrainConfig +from khaosz.trainer.train_callback import ( + TrainCallback, ProgressBarCallback, CheckpointCallback, GradientClippingCallback, @@ -20,14 +21,14 @@ class Trainer: parameter: ModelParameter, train_config: TrainConfig, schedule_config: ScheduleConfig, - callbacks: Optional[List[TrainerCallback]] = None + callbacks: Optional[List[TrainCallback]] = None ): self.parameter = parameter self.train_config = train_config self.schedule_config = schedule_config self.callbacks = callbacks or self._get_default_callbacks() - def _get_default_callbacks(self) -> List[TrainerCallback]: + def _get_default_callbacks(self) -> List[TrainCallback]: return [ ProgressBarCallback(), CheckpointCallback(self.train_config.checkpoint_interval), @@ -44,16 +45,7 @@ class Trainer: .build()) def _call_callbacks(self, method_name: str, context: TrainContext): - kwargs = { - 'dataloader': context.dataloader, - 'optimizer': context.optimizer, - 'sampler': context.sampler, - 'epoch': context.epoch, - 'current_iter': context.current_iter, - 'loss': context.loss, - 'checkpoint': context.checkpoint - } - + kwargs = context.asdict() for callback in self.callbacks: method = getattr(callback, method_name, None) if method: