refactor(khaosz/trainer): 重构训练器模块结构以提升可维护性

This commit is contained in:
ViperEkura 2025-10-04 21:31:15 +08:00
parent e7d29ca2d5
commit 2ccd7bd583
6 changed files with 97 additions and 86 deletions

View File

@@ -1,22 +1,21 @@
from khaosz.trainer.data_util import DatasetLoader from khaosz.trainer.data_util import DatasetLoader
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
from khaosz.trainer.train_config import TrainConfig
from khaosz.trainer.strategy import ( from khaosz.trainer.strategy import (
TrainConfig,
CosineScheduleConfig, CosineScheduleConfig,
SgdrScheduleConfig, SgdrScheduleConfig,
StrategyFactory, StrategyFactory,
SchedulerFactory SchedulerFactory
) )
from khaosz.trainer.trainer_callback import ( from khaosz.trainer.train_callback import (
TrainerCallback, TrainCallback,
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
TrainerCallback, TrainCallback,
SchedulerCallback SchedulerCallback
) )
__all__ = [ __all__ = [
# strategy
"DatasetLoader", "DatasetLoader",
"Trainer", "Trainer",
"TrainConfig", "TrainConfig",
@@ -26,9 +25,9 @@ __all__ = [
"SchedulerFactory", "SchedulerFactory",
# callback # callback
"TrainerCallback", "TrainCallback",
"ProgressBarCallback", "ProgressBarCallback",
"CheckpointCallback", "CheckpointCallback",
"TrainerCallback", "TrainCallback",
"SchedulerCallback", "SchedulerCallback",
] ]

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import Any, Literal, Tuple, Callable, Dict from typing import Any, Literal, Optional, Tuple, Callable, Dict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
@@ -176,56 +176,7 @@ class StrategyFactory:
} }
strategy = train_strategy[train_type]() strategy = train_strategy[train_type]()
return strategy return strategy
@dataclass
class TrainConfig:
strategy: BaseStrategy = field(
default=None,
metadata={"help": "Training strategy."}
)
dataset: Dataset = field(
default=None,
metadata={"help": "Dataset for training."}
)
optimizer: Optimizer = field(
default=None,
metadata={"help": "Optimizer for training."}
)
checkpoint_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
)
n_epoch: int = field(
default=1,
metadata={"help": "Number of epochs for training."}
)
batch_size: int = field(
default=4,
metadata={"help": "Batch size for training."}
)
checkpoint_interval: int = field(
default=5000,
metadata={"help": "Number of iterations between checkpoints."}
)
accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
)
random_seed: int = field(
default=3407,
metadata={"help": "Random seed."}
)
def get_kwargs(self)-> Dict[str, Any]:
config_dict = asdict(self)
return {k: v for k, v in config_dict.items() if v is not None}
@dataclass @dataclass
class ScheduleConfig(ABC): class ScheduleConfig(ABC):

View File

@@ -13,7 +13,7 @@ if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
class TrainerCallback: class TrainCallback:
""" """
Callback interface for trainer. Callback interface for trainer.
and we use '_' to ignore unused parameters. and we use '_' to ignore unused parameters.
@@ -52,7 +52,7 @@ class TrainerCallback:
_ = trainer, kwargs _ = trainer, kwargs
class ProgressBarCallback(TrainerCallback): class ProgressBarCallback(TrainCallback):
""" """
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
@@ -84,7 +84,7 @@ class ProgressBarCallback(TrainerCallback):
self.progress_bar.close() self.progress_bar.close()
class CheckpointCallback(TrainerCallback): class CheckpointCallback(TrainCallback):
""" """
Checkpoint callback for trainer. Checkpoint callback for trainer.
""" """
@@ -122,7 +122,7 @@ class CheckpointCallback(TrainerCallback):
self.last_ckpt_iter = current_iter self.last_ckpt_iter = current_iter
class GradientClippingCallback(TrainerCallback): class GradientClippingCallback(TrainCallback):
""" """
Gradient clipping callback for trainer. Gradient clipping callback for trainer.
""" """
@@ -134,7 +134,7 @@ class GradientClippingCallback(TrainerCallback):
) )
class SchedulerCallback(TrainerCallback): class SchedulerCallback(TrainCallback):
""" """
Scheduler callback for trainer. Scheduler callback for trainer.
""" """

View File

@@ -0,0 +1,62 @@
from dataclasses import dataclass, field
from typing import Optional
from torch.utils.data import Dataset
from torch.optim import Optimizer
from khaosz.trainer.strategy import BaseStrategy
@dataclass
class TrainConfig:
    """Configuration for a training run.

    Bundles the training strategy, data source and optimizer together
    with the checkpointing and DataLoader hyper-parameters consumed by
    the trainer. Every field has a default so a config can be built
    incrementally, but ``strategy``, ``dataset`` and ``optimizer`` must
    be set to non-None values before training starts.
    """

    # Core components — Optional only to allow incremental construction;
    # they are required at train time.
    strategy: Optional["BaseStrategy"] = field(
        default=None,
        metadata={"help": "Training strategy."}
    )
    dataset: Optional[Dataset] = field(
        default=None,
        metadata={"help": "Dataset for training."}
    )
    optimizer: Optional[Optimizer] = field(
        default=None,
        metadata={"help": "Optimizer for training."}
    )
    # Checkpointing.
    checkpoint_dir: str = field(
        default="./checkpoint",
        metadata={"help": "Checkpoint directory."}
    )
    # Optimization loop.
    n_epoch: int = field(
        default=1,
        metadata={"help": "Number of epochs for training."}
    )
    batch_size: int = field(
        default=4,
        metadata={"help": "Batch size for training."}
    )
    checkpoint_interval: int = field(
        default=5000,
        metadata={"help": "Number of iterations between checkpoints."}
    )
    accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of iterations between steps."}
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm."}
    )
    random_seed: int = field(
        default=3407,
        metadata={"help": "Random seed."}
    )
    # DataLoader performance knobs. NOTE(review): prefetch_factor=None with
    # num_workers=0 is only accepted by torch >= 2.0 — confirm the minimum
    # supported torch version.
    num_workers: int = field(
        default=0,
        metadata={"help": "Number of workers for dataloader."}
    )
    prefetch_factor: Optional[int] = field(
        default=None,
        metadata={"help": "Prefetch factor for dataloader."}
    )
    pin_memory: bool = field(
        default=False,
        metadata={"help": "Pin memory for dataloader."}
    )

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass from dataclasses import dataclass, field, fields
from typing import Optional, Self, TYPE_CHECKING from typing import Optional, Self, TYPE_CHECKING
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -11,13 +11,17 @@ if TYPE_CHECKING:
@dataclass @dataclass
class TrainContext: class TrainContext:
dataloader: DataLoader dataloader: DataLoader = field(default=None)
optimizer: Optimizer optimizer: Optimizer = field(default=None)
sampler: RandomSampler sampler: RandomSampler = field(default=None)
epoch: int epoch: int = field(default=0)
current_iter: int current_iter: int = field(default=0)
loss: float loss: float = field(default=0.0)
checkpoint: Checkpoint checkpoint: Checkpoint = field(default=None)
def asdict(self) -> dict:
return {field.name: getattr(self, field.name)
for field in fields(self)}
class TrainContextBuilder: class TrainContextBuilder:
@@ -82,7 +86,10 @@ class TrainContextBuilder:
dataloader = DataLoader( dataloader = DataLoader(
self.trainer.train_config.dataset, self.trainer.train_config.dataset,
batch_size=self.trainer.train_config.batch_size, batch_size=self.trainer.train_config.batch_size,
sampler=self._context.sampler sampler=self._context.sampler,
num_workers=self.trainer.train_config.num_workers,
pin_memory=self.trainer.train_config.pin_memory,
prefetch_factor=self.trainer.train_config.prefetch_factor
) )
self._context.dataloader = dataloader self._context.dataloader = dataloader
return self return self

View File

@@ -2,9 +2,10 @@ import logging
from typing import Optional, List from typing import Optional, List
from khaosz.core import ModelParameter, Checkpoint from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig from khaosz.trainer.strategy import ScheduleConfig
from khaosz.trainer.trainer_callback import ( from khaosz.trainer.train_config import TrainConfig
TrainerCallback, from khaosz.trainer.train_callback import (
TrainCallback,
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
GradientClippingCallback, GradientClippingCallback,
@@ -20,14 +21,14 @@ class Trainer:
parameter: ModelParameter, parameter: ModelParameter,
train_config: TrainConfig, train_config: TrainConfig,
schedule_config: ScheduleConfig, schedule_config: ScheduleConfig,
callbacks: Optional[List[TrainerCallback]] = None callbacks: Optional[List[TrainCallback]] = None
): ):
self.parameter = parameter self.parameter = parameter
self.train_config = train_config self.train_config = train_config
self.schedule_config = schedule_config self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks() self.callbacks = callbacks or self._get_default_callbacks()
def _get_default_callbacks(self) -> List[TrainerCallback]: def _get_default_callbacks(self) -> List[TrainCallback]:
return [ return [
ProgressBarCallback(), ProgressBarCallback(),
CheckpointCallback(self.train_config.checkpoint_interval), CheckpointCallback(self.train_config.checkpoint_interval),
@@ -44,16 +45,7 @@ class Trainer:
.build()) .build())
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
kwargs = { kwargs = context.asdict()
'dataloader': context.dataloader,
'optimizer': context.optimizer,
'sampler': context.sampler,
'epoch': context.epoch,
'current_iter': context.current_iter,
'loss': context.loss,
'checkpoint': context.checkpoint
}
for callback in self.callbacks: for callback in self.callbacks:
method = getattr(callback, method_name, None) method = getattr(callback, method_name, None)
if method: if method: