fix(khaosz/trainer): 调整训练循环中回调调用顺序并增强异常日志记录
This commit is contained in:
parent
dd6a9e4ede
commit
288e2c3da6
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
from typing import Optional, List, cast
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
|
@ -12,6 +13,7 @@ from khaosz.trainer.trainer_callback import (
|
|||
SchedulerCallback
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
|
|
@ -115,12 +117,6 @@ class Trainer:
|
|||
train_kwargs["epoch"] = epoch
|
||||
self._call_callbacks('on_epoch_begin', **train_kwargs)
|
||||
for batch in dataloader:
|
||||
# batch
|
||||
self._call_callbacks('on_batch_begin', **train_kwargs)
|
||||
loss = self.train_config.strategy(batch)
|
||||
loss.backward()
|
||||
train_kwargs["loss"] = loss.item()
|
||||
self._call_callbacks('on_batch_end', **train_kwargs)
|
||||
|
||||
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
|
||||
# step
|
||||
|
|
@ -128,13 +124,21 @@ class Trainer:
|
|||
self.train_config.optimizer.step()
|
||||
self.train_config.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', **train_kwargs)
|
||||
|
||||
|
||||
# batch
|
||||
self._call_callbacks('on_batch_begin', **train_kwargs)
|
||||
loss = self.train_config.strategy(batch)
|
||||
train_kwargs["loss"] = loss.item()
|
||||
train_kwargs["current_iter"] += 1
|
||||
|
||||
loss.backward()
|
||||
|
||||
self._call_callbacks('on_batch_end', **train_kwargs)
|
||||
|
||||
self._call_callbacks('on_epoch_end', **train_kwargs)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks('on_train_end', **train_kwargs)
|
||||
return checkpoint
|
||||
|
|
@ -87,10 +87,12 @@ class ProgressBarCallback(TrainerCallback):
|
|||
)
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer
|
||||
loss = kwargs.get('loss')
|
||||
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
|
||||
self.progress_bar.set_postfix({
|
||||
"loss": f"{loss:.4f}",
|
||||
"lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
|
||||
"lr": f"{optimizer.param_groups[-1]['lr']:.2e}"
|
||||
})
|
||||
self.progress_bar.update(1)
|
||||
|
||||
|
|
@ -180,7 +182,7 @@ class SchedulerCallback(TrainerCallback):
|
|||
last_epoch=self.current_iter - 1
|
||||
)
|
||||
|
||||
def on_step_end(self, trainer: 'Trainer', **kwargs):
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer, kwargs
|
||||
|
||||
if self.scheduler:
|
||||
|
|
|
|||
Loading…
Reference in New Issue