refactor(khaosz/trainer): 重构训练器模块结构以提升可维护性
This commit is contained in:
parent
e7d29ca2d5
commit
2ccd7bd583
|
|
@ -1,22 +1,21 @@
|
||||||
from khaosz.trainer.data_util import DatasetLoader
|
from khaosz.trainer.data_util import DatasetLoader
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
from khaosz.trainer.train_config import TrainConfig
|
||||||
from khaosz.trainer.strategy import (
|
from khaosz.trainer.strategy import (
|
||||||
TrainConfig,
|
|
||||||
CosineScheduleConfig,
|
CosineScheduleConfig,
|
||||||
SgdrScheduleConfig,
|
SgdrScheduleConfig,
|
||||||
StrategyFactory,
|
StrategyFactory,
|
||||||
SchedulerFactory
|
SchedulerFactory
|
||||||
)
|
)
|
||||||
from khaosz.trainer.trainer_callback import (
|
from khaosz.trainer.train_callback import (
|
||||||
TrainerCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
TrainerCallback,
|
TrainCallback,
|
||||||
SchedulerCallback
|
SchedulerCallback
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# strategy
|
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
@ -26,9 +25,9 @@ __all__ = [
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
||||||
# callback
|
# callback
|
||||||
"TrainerCallback",
|
"TrainCallback",
|
||||||
"ProgressBarCallback",
|
"ProgressBarCallback",
|
||||||
"CheckpointCallback",
|
"CheckpointCallback",
|
||||||
"TrainerCallback",
|
"TrainCallback",
|
||||||
"SchedulerCallback",
|
"SchedulerCallback",
|
||||||
]
|
]
|
||||||
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Any, Literal, Tuple, Callable, Dict
|
from typing import Any, Literal, Optional, Tuple, Callable, Dict
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
|
||||||
|
|
@ -176,56 +176,7 @@ class StrategyFactory:
|
||||||
}
|
}
|
||||||
strategy = train_strategy[train_type]()
|
strategy = train_strategy[train_type]()
|
||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainConfig:
|
|
||||||
|
|
||||||
strategy: BaseStrategy = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Training strategy."}
|
|
||||||
)
|
|
||||||
dataset: Dataset = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Dataset for training."}
|
|
||||||
)
|
|
||||||
optimizer: Optimizer = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Optimizer for training."}
|
|
||||||
)
|
|
||||||
checkpoint_dir: str = field(
|
|
||||||
default="./checkpoint",
|
|
||||||
metadata={"help": "Checkpoint directory."}
|
|
||||||
)
|
|
||||||
n_epoch: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of epochs for training."}
|
|
||||||
)
|
|
||||||
batch_size: int = field(
|
|
||||||
default=4,
|
|
||||||
metadata={"help": "Batch size for training."}
|
|
||||||
)
|
|
||||||
checkpoint_interval: int = field(
|
|
||||||
default=5000,
|
|
||||||
metadata={"help": "Number of iterations between checkpoints."}
|
|
||||||
)
|
|
||||||
accumulation_steps: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of iterations between steps."}
|
|
||||||
)
|
|
||||||
max_grad_norm: float = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "Maximum gradient norm."}
|
|
||||||
)
|
|
||||||
random_seed: int = field(
|
|
||||||
default=3407,
|
|
||||||
metadata={"help": "Random seed."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_kwargs(self)-> Dict[str, Any]:
|
|
||||||
config_dict = asdict(self)
|
|
||||||
return {k: v for k, v in config_dict.items() if v is not None}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScheduleConfig(ABC):
|
class ScheduleConfig(ABC):
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
class TrainerCallback:
|
class TrainCallback:
|
||||||
"""
|
"""
|
||||||
Callback interface for trainer.
|
Callback interface for trainer.
|
||||||
and we use '_' to ignore unused parameters.
|
and we use '_' to ignore unused parameters.
|
||||||
|
|
@ -52,7 +52,7 @@ class TrainerCallback:
|
||||||
_ = trainer, kwargs
|
_ = trainer, kwargs
|
||||||
|
|
||||||
|
|
||||||
class ProgressBarCallback(TrainerCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
@ -84,7 +84,7 @@ class ProgressBarCallback(TrainerCallback):
|
||||||
self.progress_bar.close()
|
self.progress_bar.close()
|
||||||
|
|
||||||
|
|
||||||
class CheckpointCallback(TrainerCallback):
|
class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
@ -122,7 +122,7 @@ class CheckpointCallback(TrainerCallback):
|
||||||
self.last_ckpt_iter = current_iter
|
self.last_ckpt_iter = current_iter
|
||||||
|
|
||||||
|
|
||||||
class GradientClippingCallback(TrainerCallback):
|
class GradientClippingCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Gradient clipping callback for trainer.
|
Gradient clipping callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
@ -134,7 +134,7 @@ class GradientClippingCallback(TrainerCallback):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerCallback(TrainerCallback):
|
class SchedulerCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Scheduler callback for trainer.
|
Scheduler callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from khaosz.trainer.strategy import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainConfig:
|
||||||
|
|
||||||
|
strategy: BaseStrategy = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Training strategy."}
|
||||||
|
)
|
||||||
|
dataset: Dataset = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Dataset for training."}
|
||||||
|
)
|
||||||
|
optimizer: Optimizer = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Optimizer for training."}
|
||||||
|
)
|
||||||
|
checkpoint_dir: str = field(
|
||||||
|
default="./checkpoint",
|
||||||
|
metadata={"help": "Checkpoint directory."}
|
||||||
|
)
|
||||||
|
n_epoch: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of epochs for training."}
|
||||||
|
)
|
||||||
|
batch_size: int = field(
|
||||||
|
default=4,
|
||||||
|
metadata={"help": "Batch size for training."}
|
||||||
|
)
|
||||||
|
checkpoint_interval: int = field(
|
||||||
|
default=5000,
|
||||||
|
metadata={"help": "Number of iterations between checkpoints."}
|
||||||
|
)
|
||||||
|
accumulation_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of iterations between steps."}
|
||||||
|
)
|
||||||
|
max_grad_norm: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "Maximum gradient norm."}
|
||||||
|
)
|
||||||
|
random_seed: int = field(
|
||||||
|
default=3407,
|
||||||
|
metadata={"help": "Random seed."}
|
||||||
|
)
|
||||||
|
num_workers: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={"help": "Number of workers for dataloader."}
|
||||||
|
)
|
||||||
|
prefetch_factor: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Prefetch factor for dataloader."}
|
||||||
|
)
|
||||||
|
pin_memory: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Pin memory for dataloader."}
|
||||||
|
)
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Optional, Self, TYPE_CHECKING
|
from typing import Optional, Self, TYPE_CHECKING
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
@ -11,13 +11,17 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainContext:
|
class TrainContext:
|
||||||
dataloader: DataLoader
|
dataloader: DataLoader = field(default=None)
|
||||||
optimizer: Optimizer
|
optimizer: Optimizer = field(default=None)
|
||||||
sampler: RandomSampler
|
sampler: RandomSampler = field(default=None)
|
||||||
epoch: int
|
epoch: int = field(default=0)
|
||||||
current_iter: int
|
current_iter: int = field(default=0)
|
||||||
loss: float
|
loss: float = field(default=0.0)
|
||||||
checkpoint: Checkpoint
|
checkpoint: Checkpoint = field(default=None)
|
||||||
|
|
||||||
|
def asdict(self) -> dict:
|
||||||
|
return {field.name: getattr(self, field.name)
|
||||||
|
for field in fields(self)}
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
|
|
@ -82,7 +86,10 @@ class TrainContextBuilder:
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
self.trainer.train_config.dataset,
|
self.trainer.train_config.dataset,
|
||||||
batch_size=self.trainer.train_config.batch_size,
|
batch_size=self.trainer.train_config.batch_size,
|
||||||
sampler=self._context.sampler
|
sampler=self._context.sampler,
|
||||||
|
num_workers=self.trainer.train_config.num_workers,
|
||||||
|
pin_memory=self.trainer.train_config.pin_memory,
|
||||||
|
prefetch_factor=self.trainer.train_config.prefetch_factor
|
||||||
)
|
)
|
||||||
self._context.dataloader = dataloader
|
self._context.dataloader = dataloader
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,10 @@ import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from khaosz.core import ModelParameter, Checkpoint
|
from khaosz.core import ModelParameter, Checkpoint
|
||||||
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
from khaosz.trainer.strategy import ScheduleConfig
|
||||||
from khaosz.trainer.trainer_callback import (
|
from khaosz.trainer.train_config import TrainConfig
|
||||||
TrainerCallback,
|
from khaosz.trainer.train_callback import (
|
||||||
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
|
|
@ -20,14 +21,14 @@ class Trainer:
|
||||||
parameter: ModelParameter,
|
parameter: ModelParameter,
|
||||||
train_config: TrainConfig,
|
train_config: TrainConfig,
|
||||||
schedule_config: ScheduleConfig,
|
schedule_config: ScheduleConfig,
|
||||||
callbacks: Optional[List[TrainerCallback]] = None
|
callbacks: Optional[List[TrainCallback]] = None
|
||||||
):
|
):
|
||||||
self.parameter = parameter
|
self.parameter = parameter
|
||||||
self.train_config = train_config
|
self.train_config = train_config
|
||||||
self.schedule_config = schedule_config
|
self.schedule_config = schedule_config
|
||||||
self.callbacks = callbacks or self._get_default_callbacks()
|
self.callbacks = callbacks or self._get_default_callbacks()
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainerCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(),
|
ProgressBarCallback(),
|
||||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||||
|
|
@ -44,16 +45,7 @@ class Trainer:
|
||||||
.build())
|
.build())
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
kwargs = {
|
kwargs = context.asdict()
|
||||||
'dataloader': context.dataloader,
|
|
||||||
'optimizer': context.optimizer,
|
|
||||||
'sampler': context.sampler,
|
|
||||||
'epoch': context.epoch,
|
|
||||||
'current_iter': context.current_iter,
|
|
||||||
'loss': context.loss,
|
|
||||||
'checkpoint': context.checkpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
method = getattr(callback, method_name, None)
|
method = getattr(callback, method_name, None)
|
||||||
if method:
|
if method:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue