fix(trainer): 修复训练上下文构建逻辑并修正拼写错误

This commit is contained in:
ViperEkura 2025-12-10 15:02:39 +08:00
parent 530fb50352
commit 110efd2a21
2 changed files with 10 additions and 38 deletions

View File

@@ -26,17 +26,22 @@ class TrainContext:
iteration: int = field(default=0) iteration: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
wolrd_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
class TrainContextBuilder: class TrainContextBuilder:
def __init__(self, config: TrainConfig): def __init__(self, config: TrainConfig):
self.config = config self.config = config
self._context: TrainContext = None self._context = TrainContext(
model=config.model,
optimizer=config.optimizer,
scheduler=config.scheduler,
world_size=get_world_size(),
rank=get_rank(),
)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._context = TrainContext()
if checkpoint is None: if checkpoint is None:
checkpoint = Checkpoint( checkpoint = Checkpoint(
optimizer_state=self.config.optimizer.state_dict(), optimizer_state=self.config.optimizer.state_dict(),
@@ -46,37 +51,12 @@ class TrainContextBuilder:
# resume from the assigned checkpoint or assigned iteration # resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch) self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
self._context.checkpoint = checkpoint self._context.checkpoint = checkpoint
return self return self
def with_optimizer(self) -> Self:
optimizer = self.config.optimizer
if self._context is None:
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
self._context.optimizer = optimizer
if self._context.checkpoint:
self._context.checkpoint.optimizer_state = optimizer.state_dict()
return self
def with_scheduler(self) -> Self:
scheduler = self.config.scheduler
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
self._context.scheduler = scheduler
if self._context.checkpoint:
self._context.checkpoint.scheduler_state = scheduler.state_dict()
return self
def with_dataloader(self) -> Self: def with_dataloader(self) -> Self:
# fix: change batch level iteration to sample level offset # fix: change batch level iteration to sample level offset
config = self.config config = self.config
@@ -110,10 +90,4 @@ class TrainContextBuilder:
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
self._context.model = self.config.model
if self.config.nprocs > 1:
self._context.wolrd_size = get_world_size()
self._context.rank = get_rank()
return self._context return self._context

View File

@ -34,8 +34,6 @@ class Trainer:
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self.train_config) return (TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint) .with_checkpoint(checkpoint)
.with_optimizer()
.with_scheduler()
.with_dataloader() .with_dataloader()
.with_strategy() .with_strategy()
.build()) .build())