refactor(khaosz/trainer): 重构训练器模块结构以提升可维护性
This commit is contained in:
parent
e7d29ca2d5
commit
2ccd7bd583
|
|
@ -1,22 +1,21 @@
|
|||
from khaosz.trainer.data_util import DatasetLoader
|
||||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.trainer.train_config import TrainConfig
|
||||
from khaosz.trainer.strategy import (
|
||||
TrainConfig,
|
||||
CosineScheduleConfig,
|
||||
SgdrScheduleConfig,
|
||||
StrategyFactory,
|
||||
SchedulerFactory
|
||||
)
|
||||
from khaosz.trainer.trainer_callback import (
|
||||
TrainerCallback,
|
||||
from khaosz.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
TrainerCallback,
|
||||
TrainCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# strategy
|
||||
"DatasetLoader",
|
||||
"Trainer",
|
||||
"TrainConfig",
|
||||
|
|
@ -26,9 +25,9 @@ __all__ = [
|
|||
"SchedulerFactory",
|
||||
|
||||
# callback
|
||||
"TrainerCallback",
|
||||
"TrainCallback",
|
||||
"ProgressBarCallback",
|
||||
"CheckpointCallback",
|
||||
"TrainerCallback",
|
||||
"TrainCallback",
|
||||
"SchedulerCallback",
|
||||
]
|
||||
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Any, Literal, Tuple, Callable, Dict
|
||||
from typing import Any, Literal, Optional, Tuple, Callable, Dict
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
|
@ -178,55 +178,6 @@ class StrategyFactory:
|
|||
return strategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
|
||||
strategy: BaseStrategy = field(
|
||||
default=None,
|
||||
metadata={"help": "Training strategy."}
|
||||
)
|
||||
dataset: Dataset = field(
|
||||
default=None,
|
||||
metadata={"help": "Dataset for training."}
|
||||
)
|
||||
optimizer: Optimizer = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer for training."}
|
||||
)
|
||||
checkpoint_dir: str = field(
|
||||
default="./checkpoint",
|
||||
metadata={"help": "Checkpoint directory."}
|
||||
)
|
||||
n_epoch: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of epochs for training."}
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Batch size for training."}
|
||||
)
|
||||
checkpoint_interval: int = field(
|
||||
default=5000,
|
||||
metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
accumulation_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
random_seed: int = field(
|
||||
default=3407,
|
||||
metadata={"help": "Random seed."}
|
||||
)
|
||||
|
||||
def get_kwargs(self)-> Dict[str, Any]:
|
||||
config_dict = asdict(self)
|
||||
return {k: v for k, v in config_dict.items() if v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScheduleConfig(ABC):
|
||||
schedule_type: str = field(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from khaosz.trainer.trainer import Trainer
|
||||
|
||||
|
||||
class TrainerCallback:
|
||||
class TrainCallback:
|
||||
"""
|
||||
Callback interface for trainer.
|
||||
and we use '_' to ignore unused parameters.
|
||||
|
|
@ -52,7 +52,7 @@ class TrainerCallback:
|
|||
_ = trainer, kwargs
|
||||
|
||||
|
||||
class ProgressBarCallback(TrainerCallback):
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
"""
|
||||
Progress bar callback for trainer.
|
||||
"""
|
||||
|
|
@ -84,7 +84,7 @@ class ProgressBarCallback(TrainerCallback):
|
|||
self.progress_bar.close()
|
||||
|
||||
|
||||
class CheckpointCallback(TrainerCallback):
|
||||
class CheckpointCallback(TrainCallback):
|
||||
"""
|
||||
Checkpoint callback for trainer.
|
||||
"""
|
||||
|
|
@ -122,7 +122,7 @@ class CheckpointCallback(TrainerCallback):
|
|||
self.last_ckpt_iter = current_iter
|
||||
|
||||
|
||||
class GradientClippingCallback(TrainerCallback):
|
||||
class GradientClippingCallback(TrainCallback):
|
||||
"""
|
||||
Gradient clipping callback for trainer.
|
||||
"""
|
||||
|
|
@ -134,7 +134,7 @@ class GradientClippingCallback(TrainerCallback):
|
|||
)
|
||||
|
||||
|
||||
class SchedulerCallback(TrainerCallback):
|
||||
class SchedulerCallback(TrainCallback):
|
||||
"""
|
||||
Scheduler callback for trainer.
|
||||
"""
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from torch.utils.data import Dataset
|
||||
from torch.optim import Optimizer
|
||||
from khaosz.trainer.strategy import BaseStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
|
||||
strategy: BaseStrategy = field(
|
||||
default=None,
|
||||
metadata={"help": "Training strategy."}
|
||||
)
|
||||
dataset: Dataset = field(
|
||||
default=None,
|
||||
metadata={"help": "Dataset for training."}
|
||||
)
|
||||
optimizer: Optimizer = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer for training."}
|
||||
)
|
||||
checkpoint_dir: str = field(
|
||||
default="./checkpoint",
|
||||
metadata={"help": "Checkpoint directory."}
|
||||
)
|
||||
n_epoch: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of epochs for training."}
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Batch size for training."}
|
||||
)
|
||||
checkpoint_interval: int = field(
|
||||
default=5000,
|
||||
metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
accumulation_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
random_seed: int = field(
|
||||
default=3407,
|
||||
metadata={"help": "Random seed."}
|
||||
)
|
||||
num_workers: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Number of workers for dataloader."}
|
||||
)
|
||||
prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Prefetch factor for dataloader."}
|
||||
)
|
||||
pin_memory: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Pin memory for dataloader."}
|
||||
)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Optional, Self, TYPE_CHECKING
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
|
@ -11,13 +11,17 @@ if TYPE_CHECKING:
|
|||
|
||||
@dataclass
|
||||
class TrainContext:
|
||||
dataloader: DataLoader
|
||||
optimizer: Optimizer
|
||||
sampler: RandomSampler
|
||||
epoch: int
|
||||
current_iter: int
|
||||
loss: float
|
||||
checkpoint: Checkpoint
|
||||
dataloader: DataLoader = field(default=None)
|
||||
optimizer: Optimizer = field(default=None)
|
||||
sampler: RandomSampler = field(default=None)
|
||||
epoch: int = field(default=0)
|
||||
current_iter: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
|
||||
def asdict(self) -> dict:
|
||||
return {field.name: getattr(self, field.name)
|
||||
for field in fields(self)}
|
||||
|
||||
|
||||
class TrainContextBuilder:
|
||||
|
|
@ -82,7 +86,10 @@ class TrainContextBuilder:
|
|||
dataloader = DataLoader(
|
||||
self.trainer.train_config.dataset,
|
||||
batch_size=self.trainer.train_config.batch_size,
|
||||
sampler=self._context.sampler
|
||||
sampler=self._context.sampler,
|
||||
num_workers=self.trainer.train_config.num_workers,
|
||||
pin_memory=self.trainer.train_config.pin_memory,
|
||||
prefetch_factor=self.trainer.train_config.prefetch_factor
|
||||
)
|
||||
self._context.dataloader = dataloader
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@ import logging
|
|||
from typing import Optional, List
|
||||
|
||||
from khaosz.core import ModelParameter, Checkpoint
|
||||
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
||||
from khaosz.trainer.trainer_callback import (
|
||||
TrainerCallback,
|
||||
from khaosz.trainer.strategy import ScheduleConfig
|
||||
from khaosz.trainer.train_config import TrainConfig
|
||||
from khaosz.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
GradientClippingCallback,
|
||||
|
|
@ -20,14 +21,14 @@ class Trainer:
|
|||
parameter: ModelParameter,
|
||||
train_config: TrainConfig,
|
||||
schedule_config: ScheduleConfig,
|
||||
callbacks: Optional[List[TrainerCallback]] = None
|
||||
callbacks: Optional[List[TrainCallback]] = None
|
||||
):
|
||||
self.parameter = parameter
|
||||
self.train_config = train_config
|
||||
self.schedule_config = schedule_config
|
||||
self.callbacks = callbacks or self._get_default_callbacks()
|
||||
|
||||
def _get_default_callbacks(self) -> List[TrainerCallback]:
|
||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||
return [
|
||||
ProgressBarCallback(),
|
||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||
|
|
@ -44,16 +45,7 @@ class Trainer:
|
|||
.build())
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
kwargs = {
|
||||
'dataloader': context.dataloader,
|
||||
'optimizer': context.optimizer,
|
||||
'sampler': context.sampler,
|
||||
'epoch': context.epoch,
|
||||
'current_iter': context.current_iter,
|
||||
'loss': context.loss,
|
||||
'checkpoint': context.checkpoint
|
||||
}
|
||||
|
||||
kwargs = context.asdict()
|
||||
for callback in self.callbacks:
|
||||
method = getattr(callback, method_name, None)
|
||||
if method:
|
||||
|
|
|
|||
Loading…
Reference in New Issue