fix(khaosz/trainer): 调整训练循环中回调调用顺序并增强异常日志记录
This commit is contained in:
parent
dd6a9e4ede
commit
288e2c3da6
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
from typing import Optional, List, cast
|
from typing import Optional, List, cast
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
@ -12,6 +13,7 @@ from khaosz.trainer.trainer_callback import (
|
||||||
SchedulerCallback
|
SchedulerCallback
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -115,12 +117,6 @@ class Trainer:
|
||||||
train_kwargs["epoch"] = epoch
|
train_kwargs["epoch"] = epoch
|
||||||
self._call_callbacks('on_epoch_begin', **train_kwargs)
|
self._call_callbacks('on_epoch_begin', **train_kwargs)
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# batch
|
|
||||||
self._call_callbacks('on_batch_begin', **train_kwargs)
|
|
||||||
loss = self.train_config.strategy(batch)
|
|
||||||
loss.backward()
|
|
||||||
train_kwargs["loss"] = loss.item()
|
|
||||||
self._call_callbacks('on_batch_end', **train_kwargs)
|
|
||||||
|
|
||||||
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
|
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
|
||||||
# step
|
# step
|
||||||
|
|
@ -128,13 +124,21 @@ class Trainer:
|
||||||
self.train_config.optimizer.step()
|
self.train_config.optimizer.step()
|
||||||
self.train_config.optimizer.zero_grad()
|
self.train_config.optimizer.zero_grad()
|
||||||
self._call_callbacks('on_step_end', **train_kwargs)
|
self._call_callbacks('on_step_end', **train_kwargs)
|
||||||
|
|
||||||
|
# batch
|
||||||
|
self._call_callbacks('on_batch_begin', **train_kwargs)
|
||||||
|
loss = self.train_config.strategy(batch)
|
||||||
|
train_kwargs["loss"] = loss.item()
|
||||||
train_kwargs["current_iter"] += 1
|
train_kwargs["current_iter"] += 1
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
self._call_callbacks('on_batch_end', **train_kwargs)
|
||||||
|
|
||||||
self._call_callbacks('on_epoch_end', **train_kwargs)
|
self._call_callbacks('on_epoch_end', **train_kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks('on_train_end', **train_kwargs)
|
self._call_callbacks('on_train_end', **train_kwargs)
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
@ -87,10 +87,12 @@ class ProgressBarCallback(TrainerCallback):
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||||
|
_ = trainer
|
||||||
loss = kwargs.get('loss')
|
loss = kwargs.get('loss')
|
||||||
|
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
|
||||||
self.progress_bar.set_postfix({
|
self.progress_bar.set_postfix({
|
||||||
"loss": f"{loss:.4f}",
|
"loss": f"{loss:.4f}",
|
||||||
"lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
|
"lr": f"{optimizer.param_groups[-1]['lr']:.2e}"
|
||||||
})
|
})
|
||||||
self.progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
|
|
@ -180,7 +182,7 @@ class SchedulerCallback(TrainerCallback):
|
||||||
last_epoch=self.current_iter - 1
|
last_epoch=self.current_iter - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_step_end(self, trainer: 'Trainer', **kwargs):
|
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||||
_ = trainer, kwargs
|
_ = trainer, kwargs
|
||||||
|
|
||||||
if self.scheduler:
|
if self.scheduler:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue