import torch.nn as nn

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.config.train_config import TrainConfig
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank

from dataclasses import dataclass, field
from typing import Optional, Self

@dataclass
|
|
class TrainContext:
|
|
model: nn.Module = field(default=None)
|
|
strategy: BaseStrategy = field(default=None)
|
|
dataloader: DataLoader = field(default=None)
|
|
optimizer: Optimizer = field(default=None)
|
|
scheduler: LRScheduler = field(default=None)
|
|
checkpoint: Checkpoint = field(default=None)
|
|
|
|
epoch: int = field(default=0)
|
|
iteration: int = field(default=0)
|
|
loss: float = field(default=0.0)
|
|
|
|
wolrd_size: int = field(default=1)
|
|
rank: int = field(default=0)
class TrainContextBuilder:
|
|
def __init__(self, config: TrainConfig):
|
|
self.config = config
|
|
self._context: TrainContext = None
|
|
|
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
|
self._context = TrainContext()
|
|
if checkpoint is None:
|
|
checkpoint = Checkpoint(
|
|
optimizer_state=self.config.optimizer.state_dict(),
|
|
scheduler_state=self.config.scheduler.state_dict(),
|
|
)
|
|
else:
|
|
# resume from the assigned checkpoint or assigned iteration
|
|
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
|
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
|
|
|
self._context.checkpoint = checkpoint
|
|
return self
|
|
|
|
def with_optimizer(self) -> Self:
|
|
optimizer = self.config.optimizer
|
|
if self._context is None:
|
|
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
|
|
|
|
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
|
optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
|
|
|
|
self._context.optimizer = optimizer
|
|
|
|
if self._context.checkpoint:
|
|
self._context.checkpoint.optimizer_state = optimizer.state_dict()
|
|
|
|
return self
|
|
|
|
def with_scheduler(self) -> Self:
|
|
scheduler = self.config.scheduler
|
|
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
|
|
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
|
|
|
|
self._context.scheduler = scheduler
|
|
|
|
if self._context.checkpoint:
|
|
self._context.checkpoint.scheduler_state = scheduler.state_dict()
|
|
|
|
return self
|
|
|
|
def with_dataloader(self) -> Self:
|
|
# fix: change batch level iteration to sample level offset
|
|
config = self.config
|
|
sampler_offset = self._context.iteration * config.batch_size
|
|
resumeable_sampler = ResumableDistributedSampler(
|
|
data_source=config.dataset,
|
|
start_epoch=self._context.epoch,
|
|
start_iter=sampler_offset,
|
|
seed=config.random_seed
|
|
)
|
|
|
|
dataloader = DataLoader(
|
|
config.dataset,
|
|
batch_size=config.batch_size,
|
|
sampler=resumeable_sampler,
|
|
num_workers=config.num_workers,
|
|
pin_memory=config.pin_memory,
|
|
prefetch_factor=config.prefetch_factor
|
|
)
|
|
self._context.dataloader = dataloader
|
|
return self
|
|
|
|
def with_strategy(self) -> Self:
|
|
device = get_current_device()
|
|
self._context.strategy = StrategyFactory.load(
|
|
model=self.config.model,
|
|
train_type=self.config.strategy,
|
|
device=device,
|
|
**self.config.kwargs
|
|
)
|
|
return self
|
|
|
|
def build(self) -> TrainContext:
|
|
self._context.model = self.config.model
|
|
|
|
if self.config.nprocs > 1:
|
|
self._context.wolrd_size = get_world_size()
|
|
self._context.rank = get_rank()
|
|
|
|
return self._context |