fix(khaosz/trainer): 修正训练步数计算逻辑
This commit is contained in:
parent
c104a400e7
commit
e467420475
|
|
@ -154,6 +154,11 @@ class SchedulerCallback(TrainerCallback):
|
||||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||||
self.current_iter = len(checkpoint.loss_list)
|
self.current_iter = len(checkpoint.loss_list)
|
||||||
|
|
||||||
|
for group in trainer.train_config.optimizer.param_groups:
|
||||||
|
if "initial_lr" not in group:
|
||||||
|
group["initial_lr"] = group["lr"]
|
||||||
|
|
||||||
self.schedule_config.validate()
|
self.schedule_config.validate()
|
||||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||||
self.schedule_config
|
self.schedule_config
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import itertools
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
|
|
||||||
|
|
@ -38,16 +39,21 @@ class Trainer:
|
||||||
SchedulerCallback(self.schedule_config),
|
SchedulerCallback(self.schedule_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _create_dataloader(self) -> DataLoader:
|
def _create_dataloader(self, start_index: int = 0) -> DataLoader:
|
||||||
seed = self.train_config.random_seed
|
seed = self.train_config.random_seed
|
||||||
generator = torch.Generator().manual_seed(seed)
|
generator = torch.Generator().manual_seed(seed)
|
||||||
sampler = RandomSampler(self.train_config.dataset, generator=generator)
|
sampler = RandomSampler(self.train_config.dataset, generator=generator)
|
||||||
return DataLoader(
|
dataloader = DataLoader(
|
||||||
self.train_config.dataset,
|
self.train_config.dataset,
|
||||||
batch_size=self.train_config.batch_size,
|
batch_size=self.train_config.batch_size,
|
||||||
sampler=sampler
|
sampler=sampler
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if start_index > 0:
|
||||||
|
dataloader = itertools.islice(dataloader, start_index, None)
|
||||||
|
|
||||||
|
return dataloader
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, **kwargs):
|
def _call_callbacks(self, method_name: str, **kwargs):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
method = getattr(callback, method_name, None)
|
method = getattr(callback, method_name, None)
|
||||||
|
|
@ -58,34 +64,29 @@ class Trainer:
|
||||||
self,
|
self,
|
||||||
train_checkpoint: Optional[Checkpoint] = None
|
train_checkpoint: Optional[Checkpoint] = None
|
||||||
) -> Checkpoint:
|
) -> Checkpoint:
|
||||||
assert self.schedule_config.schedule_type in ["cosine", "sgdr"]
|
|
||||||
|
|
||||||
if train_checkpoint:
|
if train_checkpoint:
|
||||||
self.checkpoint = train_checkpoint
|
self.checkpoint = train_checkpoint
|
||||||
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
||||||
|
else:
|
||||||
|
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
||||||
|
|
||||||
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
|
|
||||||
current_iter = len(self.checkpoint.loss_list)
|
current_iter = len(self.checkpoint.loss_list)
|
||||||
|
total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
|
||||||
|
total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
|
||||||
|
|
||||||
for group in self.train_config.optimizer.param_groups:
|
current_epochs = total_reamining_steps // total_steps_per_epoch
|
||||||
if "initial_lr" not in group:
|
current_steps = total_reamining_steps % total_steps_per_epoch
|
||||||
group["initial_lr"] = group["lr"]
|
|
||||||
|
|
||||||
reamining_steps = self.train_config.n_epoch - current_iter
|
|
||||||
total_steps = len(self.train_config.dataset) // self.train_config.batch_size
|
|
||||||
remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
|
|
||||||
# train
|
# train
|
||||||
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
|
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
|
||||||
|
self.checkpoint.model.train()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(remaining_epochs):
|
for epoch in range(current_epochs):
|
||||||
self.checkpoint.model.train()
|
|
||||||
|
|
||||||
# epoch
|
# epoch
|
||||||
self._call_callbacks('on_epoch_begin', epoch=epoch)
|
self._call_callbacks('on_epoch_begin', epoch=epoch)
|
||||||
|
dataloader = self._create_dataloader(start_index=current_steps)
|
||||||
dataloader = self._create_dataloader()
|
|
||||||
|
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# batch
|
# batch
|
||||||
self._call_callbacks('on_batch_begin', batch=batch)
|
self._call_callbacks('on_batch_begin', batch=batch)
|
||||||
|
|
@ -110,5 +111,4 @@ class Trainer:
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
|
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
|
||||||
|
return self.checkpoint
|
||||||
return self.checkpoint
|
|
||||||
|
|
@ -204,9 +204,9 @@ def test_checkpoint_train(test_env):
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
checkpoint_dir=test_env["test_dir"],
|
checkpoint_dir=test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=5,
|
checkpoint_interval=1,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42
|
random_seed=42
|
||||||
|
|
@ -218,13 +218,18 @@ def test_checkpoint_train(test_env):
|
||||||
pad_token_id=test_env["tokenizer"].pad_id
|
pad_token_id=test_env["tokenizer"].pad_id
|
||||||
)
|
)
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
warmup_steps=100,
|
warmup_steps=1,
|
||||||
total_steps=1000
|
total_steps=5
|
||||||
)
|
)
|
||||||
trainer = Trainer(param, train_config, schedule_config)
|
trainer = Trainer(param, train_config, schedule_config)
|
||||||
|
|
||||||
|
checkpoint = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trainer.train()
|
checkpoint = trainer.train()
|
||||||
except Exception:
|
except Exception:
|
||||||
checkpoint = trainer.checkpoint
|
pass
|
||||||
trainer.train(train_checkpoint=checkpoint)
|
|
||||||
|
checkpoint = trainer.train(train_checkpoint=checkpoint)
|
||||||
|
assert len(checkpoint.loss_list) == 5 - 1
|
||||||
|
|
||||||
Loading…
Reference in New Issue