fix(trainer): 修复训练上下文构建逻辑并修正拼写错误
This commit is contained in:
parent
530fb50352
commit
110efd2a21
|
|
@@ -26,17 +26,22 @@ class TrainContext:
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
|
|
||||||
wolrd_size: int = field(default=1)
|
world_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, config: TrainConfig):
|
def __init__(self, config: TrainConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._context: TrainContext = None
|
self._context = TrainContext(
|
||||||
|
model=config.model,
|
||||||
|
optimizer=config.optimizer,
|
||||||
|
scheduler=config.scheduler,
|
||||||
|
world_size=get_world_size(),
|
||||||
|
rank=get_rank(),
|
||||||
|
)
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
self._context = TrainContext()
|
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
optimizer_state=self.config.optimizer.state_dict(),
|
optimizer_state=self.config.optimizer.state_dict(),
|
||||||
|
|
@@ -46,37 +51,12 @@ class TrainContextBuilder:
|
||||||
# resume from the assigned checkpoint or assigned iteration
|
# resume from the assigned checkpoint or assigned iteration
|
||||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||||
|
self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
|
||||||
|
self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_optimizer(self) -> Self:
|
|
||||||
optimizer = self.config.optimizer
|
|
||||||
if self._context is None:
|
|
||||||
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
|
|
||||||
|
|
||||||
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
|
||||||
optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
|
|
||||||
|
|
||||||
self._context.optimizer = optimizer
|
|
||||||
|
|
||||||
if self._context.checkpoint:
|
|
||||||
self._context.checkpoint.optimizer_state = optimizer.state_dict()
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_scheduler(self) -> Self:
|
|
||||||
scheduler = self.config.scheduler
|
|
||||||
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
|
|
||||||
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
|
|
||||||
|
|
||||||
self._context.scheduler = scheduler
|
|
||||||
|
|
||||||
if self._context.checkpoint:
|
|
||||||
self._context.checkpoint.scheduler_state = scheduler.state_dict()
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
def with_dataloader(self) -> Self:
|
||||||
# fix: change batch level iteration to sample level offset
|
# fix: change batch level iteration to sample level offset
|
||||||
config = self.config
|
config = self.config
|
||||||
|
|
@@ -110,10 +90,4 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
self._context.model = self.config.model
|
|
||||||
|
|
||||||
if self.config.nprocs > 1:
|
|
||||||
self._context.wolrd_size = get_world_size()
|
|
||||||
self._context.rank = get_rank()
|
|
||||||
|
|
||||||
return self._context
|
return self._context
|
||||||
|
|
@@ -34,8 +34,6 @@ class Trainer:
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (TrainContextBuilder(self.train_config)
|
return (TrainContextBuilder(self.train_config)
|
||||||
.with_checkpoint(checkpoint)
|
.with_checkpoint(checkpoint)
|
||||||
.with_optimizer()
|
|
||||||
.with_scheduler()
|
|
||||||
.with_dataloader()
|
.with_dataloader()
|
||||||
.with_strategy()
|
.with_strategy()
|
||||||
.build())
|
.build())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue