fix(trainer): 修复训练上下文构建逻辑并修正拼写错误
This commit is contained in:
parent
530fb50352
commit
110efd2a21
|
|
@ -26,17 +26,22 @@ class TrainContext:
|
|||
iteration: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
|
||||
wolrd_size: int = field(default=1)
|
||||
world_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
|
||||
|
||||
class TrainContextBuilder:
|
||||
def __init__(self, config: TrainConfig):
    """Seed the builder with an initial ``TrainContext`` derived from *config*.

    Args:
        config: Training configuration supplying the model, optimizer and
            scheduler that the context starts from.
    """
    self.config = config
    # Fix: the previous ``self._context: TrainContext = None`` was dead
    # code (immediately overwritten below) and contradicted its own type
    # annotation. Construct the context directly instead.
    self._context: TrainContext = TrainContext(
        model=config.model,
        optimizer=config.optimizer,
        scheduler=config.scheduler,
        # NOTE(review): distributed info is captured eagerly here;
        # build() refreshes it again when config.nprocs > 1.
        world_size=get_world_size(),
        rank=get_rank(),
    )
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._context = TrainContext()
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
optimizer_state=self.config.optimizer.state_dict(),
|
||||
|
|
@ -46,37 +51,12 @@ class TrainContextBuilder:
|
|||
# resume from the assigned checkpoint or assigned iteration
|
||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||
self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
|
||||
self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
|
||||
|
||||
self._context.checkpoint = checkpoint
|
||||
return self
|
||||
|
||||
def with_optimizer(self) -> Self:
    """Attach the configured optimizer, syncing its state with the checkpoint."""
    if self._context is None:
        raise RuntimeError("Must call with_checkpoint() before with_optimizer()")

    optimizer = self.config.optimizer
    checkpoint = self._context.checkpoint

    # Restore the optimizer from the checkpoint when it carries a state.
    if checkpoint and checkpoint.optimizer_state:
        optimizer.load_state_dict(checkpoint.optimizer_state)

    self._context.optimizer = optimizer

    # Mirror the (possibly just-restored) live optimizer state back into
    # the checkpoint so the two never drift apart.
    if checkpoint:
        checkpoint.optimizer_state = optimizer.state_dict()

    return self
|
||||
|
||||
def with_scheduler(self) -> Self:
    """Attach the configured scheduler, syncing its state with the checkpoint.

    Raises:
        RuntimeError: If called before with_checkpoint() initialised the
            context.
    """
    # Consistency fix: with_optimizer() guards against a missing context,
    # but this method previously dereferenced ``self._context`` blindly,
    # turning a misuse of the builder into an opaque AttributeError.
    if self._context is None:
        raise RuntimeError("Must call with_checkpoint() before with_scheduler()")

    scheduler = self.config.scheduler
    # Restore the scheduler from the checkpoint when it carries a state.
    if self._context.checkpoint and self._context.checkpoint.scheduler_state:
        scheduler.load_state_dict(self._context.checkpoint.scheduler_state)

    self._context.scheduler = scheduler

    # Write the live scheduler state back so the checkpoint always
    # reflects the scheduler actually in use.
    if self._context.checkpoint:
        self._context.checkpoint.scheduler_state = scheduler.state_dict()

    return self
|
||||
|
||||
def with_dataloader(self) -> Self:
|
||||
# fix: change batch level iteration to sample level offset
|
||||
config = self.config
|
||||
|
|
@ -110,10 +90,4 @@ class TrainContextBuilder:
|
|||
return self
|
||||
|
||||
def build(self) -> TrainContext:
    """Finalize and return the assembled training context.

    Returns:
        The fully populated ``TrainContext``.
    """
    self._context.model = self.config.model

    if self.config.nprocs > 1:
        # Bug fix: this assignment was misspelled ``wolrd_size``. The
        # TrainContext field was renamed to ``world_size``, so the old
        # code silently created a stray attribute and left the real
        # field at its default of 1 even in multi-process runs.
        self._context.world_size = get_world_size()
        self._context.rank = get_rank()

    return self._context
|
||||
|
|
@ -34,8 +34,6 @@ class Trainer:
|
|||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
    """Assemble a TrainContext through the fluent builder pipeline."""
    builder = TrainContextBuilder(self.train_config)
    configured = (
        builder.with_checkpoint(checkpoint)
        .with_optimizer()
        .with_scheduler()
        .with_dataloader()
        .with_strategy()
    )
    return configured.build()
|
||||
|
|
|
|||
Loading…
Reference in New Issue