fix(trainer): fix training context construction logic and correct a typo

ViperEkura 2025-12-10 15:02:39 +08:00
parent 530fb50352
commit 110efd2a21
2 changed files with 10 additions and 38 deletions


@@ -26,17 +26,22 @@ class TrainContext:
     iteration: int = field(default=0)
     loss: float = field(default=0.0)
-    wolrd_size: int = field(default=1)
+    world_size: int = field(default=1)
     rank: int = field(default=0)


 class TrainContextBuilder:
     def __init__(self, config: TrainConfig):
         self.config = config
-        self._context: TrainContext = None
+        self._context = TrainContext(
+            model=config.model,
+            optimizer=config.optimizer,
+            scheduler=config.scheduler,
+            world_size=get_world_size(),
+            rank=get_rank(),
+        )

     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
-        self._context = TrainContext()
         if checkpoint is None:
             checkpoint = Checkpoint(
                 optimizer_state=self.config.optimizer.state_dict(),
@@ -46,37 +51,12 @@ class TrainContextBuilder:
         # resume from the assigned checkpoint or assigned iteration
         self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
         self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
+        self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
+        self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
         self._context.checkpoint = checkpoint
         return self

-    def with_optimizer(self) -> Self:
-        optimizer = self.config.optimizer
-        if self._context is None:
-            raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
-        if self._context.checkpoint and self._context.checkpoint.optimizer_state:
-            optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
-        self._context.optimizer = optimizer
-        if self._context.checkpoint:
-            self._context.checkpoint.optimizer_state = optimizer.state_dict()
-        return self
-
-    def with_scheduler(self) -> Self:
-        scheduler = self.config.scheduler
-        if self._context.checkpoint and self._context.checkpoint.scheduler_state:
-            scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
-        self._context.scheduler = scheduler
-        if self._context.checkpoint:
-            self._context.checkpoint.scheduler_state = scheduler.state_dict()
-        return self

     def with_dataloader(self) -> Self:
         # fix: change batch level iteration to sample level offset
         config = self.config
@@ -110,10 +90,4 @@ class TrainContextBuilder:
         return self

     def build(self) -> TrainContext:
-        self._context.model = self.config.model
-        if self.config.nprocs > 1:
-            self._context.wolrd_size = get_world_size()
-            self._context.rank = get_rank()
         return self._context
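
With this change, all checkpoint restoration happens inside with_checkpoint(): the builder's constructor now wires model, optimizer, scheduler, world_size, and rank up front, and with_optimizer()/with_scheduler() disappear. A minimal runnable sketch of the resulting resume semantics follows; Stateful, the Checkpoint fields, and the start_epoch/start_batch parameters are simplified stand-ins inferred from the diff, not the project's real classes:

    from dataclasses import dataclass
    from typing import Optional

    class Stateful:
        """Toy stand-in for an optimizer/scheduler; only the state_dict protocol matters."""
        def __init__(self) -> None:
            self.state = {"step": 0}
        def state_dict(self) -> dict:
            return dict(self.state)
        def load_state_dict(self, state: dict) -> None:
            self.state = dict(state)

    @dataclass
    class Checkpoint:
        optimizer_state: dict
        scheduler_state: dict
        epoch: int = 0
        iteration: int = 0

    @dataclass
    class TrainContext:
        optimizer: Stateful
        scheduler: Stateful
        epoch: int = 0
        iteration: int = 0
        checkpoint: Optional[Checkpoint] = None

    def with_checkpoint(ctx: TrainContext, start_epoch: int, start_batch: int,
                        checkpoint: Optional[Checkpoint]) -> TrainContext:
        if checkpoint is None:
            # Cold start: synthesize a checkpoint from current state, so the
            # load_state_dict calls below are harmless no-ops.
            checkpoint = Checkpoint(optimizer_state=ctx.optimizer.state_dict(),
                                    scheduler_state=ctx.scheduler.state_dict())
        # Resume from whichever is further along: the checkpoint or the config.
        ctx.epoch = max(checkpoint.epoch, start_epoch)
        ctx.iteration = max(checkpoint.iteration, start_batch)
        ctx.optimizer.load_state_dict(checkpoint.optimizer_state)
        ctx.scheduler.load_state_dict(checkpoint.scheduler_state)
        ctx.checkpoint = checkpoint
        return ctx

    ctx = with_checkpoint(TrainContext(Stateful(), Stateful()), 0, 0, None)
    print(ctx.epoch, ctx.iteration)  # 0 0 on a cold start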


@@ -34,8 +34,6 @@ class Trainer:
     def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
         return (TrainContextBuilder(self.train_config)
                 .with_checkpoint(checkpoint)
-                .with_optimizer()
-                .with_scheduler()
                 .with_dataloader()
                 .with_strategy()
                 .build())
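
Because the context is now fully initialized in __init__ before any with_* step runs, the ordering constraint that the old RuntimeError guarded against ("Must call with_checkpoint() before with_optimizer()") no longer exists. For reference, a hypothetical call site after the change; TrainContextBuilder and its methods are the ones in the diff above, while train_config and latest_checkpoint are placeholder names:

    ctx = (TrainContextBuilder(train_config)
           .with_checkpoint(latest_checkpoint)  # also restores optimizer/scheduler state now
           .with_dataloader()                   # resumes at a sample-level offset
           .with_strategy()
           .build())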