fix(khaosz/trainer): correct the training-step calculation logic
parent c104a400e7
commit e467420475
@@ -154,6 +154,11 @@ class SchedulerCallback(TrainerCallback):
     def on_train_begin(self, trainer: 'Trainer', **kwargs):
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
         self.current_iter = len(checkpoint.loss_list)
 
+        for group in trainer.train_config.optimizer.param_groups:
+            if "initial_lr" not in group:
+                group["initial_lr"] = group["lr"]
+
         self.schedule_config.validate()
         lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
             self.schedule_config
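Why this hunk matters: when a torch LR scheduler is constructed with `last_epoch > -1` (i.e. resuming mid-run), it reads `initial_lr` from every optimizer param group and raises a `KeyError` if the key is missing. Stamping it in `on_train_begin` makes resume-from-checkpoint safe. Below is a minimal, self-contained sketch of that failure mode and the guard, assuming `SchedulerFactory` ultimately builds a `LambdaLR`-style torch scheduler; the model, optimizer, and toy schedule are stand-ins, not Khaosz code:

```python
# Sketch: why "initial_lr" must be stamped before resuming a LambdaLR.
# PyTorch raises a KeyError if a scheduler is built with last_epoch > -1
# and a param group lacks "initial_lr".
import torch

model = torch.nn.Linear(4, 4)                      # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

current_iter = 10                                  # steps already trained

for group in optimizer.param_groups:
    if "initial_lr" not in group:
        group["initial_lr"] = group["lr"]          # same guard as the diff

# With initial_lr present, the scheduler can pick up mid-run:
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: 0.5 ** (step // 5),     # toy schedule
    last_epoch=current_iter - 1,                   # resume position
)
print(optimizer.param_groups[0]["lr"])             # lr as of step 10
```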
@@ -1,4 +1,5 @@
 import torch
+import itertools
 from typing import Optional, List
 from torch.utils.data import DataLoader, RandomSampler
 
@@ -38,16 +39,21 @@ class Trainer:
             SchedulerCallback(self.schedule_config),
         ]
 
-    def _create_dataloader(self) -> DataLoader:
+    def _create_dataloader(self, start_index: int = 0) -> DataLoader:
         seed = self.train_config.random_seed
         generator = torch.Generator().manual_seed(seed)
         sampler = RandomSampler(self.train_config.dataset, generator=generator)
-        return DataLoader(
+        dataloader = DataLoader(
             self.train_config.dataset,
             batch_size=self.train_config.batch_size,
             sampler=sampler
         )
+
+        if start_index > 0:
+            dataloader = itertools.islice(dataloader, start_index, None)
+
+        return dataloader
 
     def _call_callbacks(self, method_name: str, **kwargs):
         for callback in self.callbacks:
             method = getattr(callback, method_name, None)
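The `start_index` mechanism above resumes mid-epoch by skipping batches that were already consumed. This only lands on the right batch because the sampler is seeded with the fixed `random_seed`, so the shuffle order is identical on every call. Two caveats worth noting: `itertools.islice` still iterates (loads and collates) the skipped batches before discarding them, so a resume pays a pass over the skipped data; and with the islice branch the function can return a plain iterator rather than a `DataLoader`, so the `-> DataLoader` annotation is now loose. A minimal sketch of the skip behaviour with a stand-in dataset (the names here are illustrative, not Khaosz's):

```python
# Sketch: seeded sampler + islice = deterministic mid-epoch resume.
import itertools
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

def make_loader(start_index=0):
    # Fixed seed => identical shuffle order on every call, so skipping
    # start_index batches resumes at exactly the interrupted position.
    generator = torch.Generator().manual_seed(42)
    sampler = RandomSampler(dataset, generator=generator)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    if start_index > 0:
        # islice still *iterates* the skipped batches (loaded, collated,
        # then thrown away) before yielding the rest.
        loader = itertools.islice(loader, start_index, None)
    return loader

full = [b[0].tolist() for b in make_loader()]
resumed = [b[0].tolist() for b in make_loader(start_index=3)]
assert resumed == full[3:]   # resume picks up exactly where it left off
```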
@@ -58,34 +64,29 @@ class Trainer:
         self,
         train_checkpoint: Optional[Checkpoint] = None
     ) -> Checkpoint:
         assert self.schedule_config.schedule_type in ["cosine", "sgdr"]
 
         if train_checkpoint:
             self.checkpoint = train_checkpoint
             self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
 
         else:
             self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
 
         current_iter = len(self.checkpoint.loss_list)
+        total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
+        total_remaining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
+        current_epochs = total_remaining_steps // total_steps_per_epoch
+        current_steps = total_remaining_steps % total_steps_per_epoch
 
-        for group in self.train_config.optimizer.param_groups:
-            if "initial_lr" not in group:
-                group["initial_lr"] = group["lr"]
-
-        reamining_steps = self.train_config.n_epoch - current_iter
-        total_steps = len(self.train_config.dataset) // self.train_config.batch_size
-        remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
         # train
         self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
 
         try:
-            for epoch in range(remaining_epochs):
+            for epoch in range(current_epochs):
                 self.checkpoint.model.train()
 
                 # epoch
                 self._call_callbacks('on_epoch_begin', epoch=epoch)
 
-                dataloader = self._create_dataloader()
+                dataloader = self._create_dataloader(start_index=current_steps)
                 for batch in dataloader:
                     # batch
                     self._call_callbacks('on_batch_begin', batch=batch)
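This hunk is the heart of the fix. The old code computed `reamining_steps = n_epoch - current_iter`, subtracting completed optimizer steps from a count of epochs; the unit mismatch made the remaining-work estimate meaningless as soon as more steps than epochs had run. The new code measures everything in steps: total work is `total_steps_per_epoch * n_epoch`, completed work is `len(loss_list)`, and the remainder splits into whole epochs plus a mid-epoch offset fed to `_create_dataloader(start_index=...)`. A worked example with illustrative numbers:

```python
# Worked example of the fixed arithmetic (values are illustrative:
# 100 samples, batch size 4, 3 epochs, 37 steps already logged).
dataset_len, batch_size, n_epoch = 100, 4, 3
current_iter = 37   # len(checkpoint.loss_list): optimizer steps completed

total_steps_per_epoch = dataset_len // batch_size                       # 25
total_remaining_steps = total_steps_per_epoch * n_epoch - current_iter  # 38

current_epochs = total_remaining_steps // total_steps_per_epoch         # 1
current_steps = total_remaining_steps % total_steps_per_epoch           # 13

# The old logic mixed units: n_epoch - current_iter = 3 - 37 = -34,
# i.e. it subtracted completed *steps* from a count of *epochs*.
print(current_epochs, current_steps)  # 1 13
```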
@@ -110,5 +111,4 @@ class Trainer:
 
         finally:
             self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
-
         return self.checkpoint
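The `finally` hunk above guarantees `on_train_end` fires even when training aborts, which the updated test below relies on to recover `trainer.checkpoint` after a failure. A minimal sketch of the dispatch-by-name callback pattern (`getattr(callback, method_name, None)`, as in `_call_callbacks`) combined with that guarantee; `LoggingCallback` is a stand-in, not Khaosz code:

```python
# Sketch: optional hooks dispatched by name, with a try/finally that
# mirrors the diff's guarantee that on_train_end always runs.
class LoggingCallback:
    def on_train_begin(self, **kwargs):
        print("train begin")

    def on_train_end(self, **kwargs):
        print("train end (fires even when training aborts)")

callbacks = [LoggingCallback()]

def call_callbacks(method_name, **kwargs):
    for cb in callbacks:
        method = getattr(cb, method_name, None)  # hooks are optional
        if callable(method):
            method(**kwargs)

call_callbacks("on_train_begin")
try:
    try:
        raise RuntimeError("simulated mid-training crash")
    finally:
        call_callbacks("on_train_end")  # mirrors the diff's finally block
except RuntimeError:
    pass  # the exception still propagates after on_train_end has fired
```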
@@ -204,9 +204,9 @@ def test_checkpoint_train(test_env):
         dataset=dataset,
         optimizer=optimizer,
         checkpoint_dir=test_env["test_dir"],
-        n_epoch=2,
+        n_epoch=1,
         batch_size=2,
-        checkpoint_interval=5,
+        checkpoint_interval=1,
         accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
@@ -218,13 +218,18 @@ def test_checkpoint_train(test_env):
         pad_token_id=test_env["tokenizer"].pad_id
     )
     schedule_config = CosineScheduleConfig(
-        warmup_steps=100,
-        total_steps=1000
+        warmup_steps=1,
+        total_steps=5
     )
     trainer = Trainer(param, train_config, schedule_config)
 
+    checkpoint = None
+
     try:
-        trainer.train()
+        checkpoint = trainer.train()
     except Exception:
         checkpoint = trainer.checkpoint
-        trainer.train(train_checkpoint=checkpoint)
-        pass
+
+    checkpoint = trainer.train(train_checkpoint=checkpoint)
+    assert len(checkpoint.loss_list) == 5 - 1