fix: 修复callback 时机不一致的问题

This commit is contained in:
ViperEkura 2026-03-06 10:51:22 +08:00
parent 96744ac2d2
commit 82d22c5742
3 changed files with 6 additions and 6 deletions

View File

@ -8,7 +8,7 @@ from khaosz.trainer.train_callback import (
CheckpointCallback,
TrainCallback,
SchedulerCallback,
StepMonitorCallback
MetricLoggerCallback
)
__all__ = [
@ -25,5 +25,5 @@ __all__ = [
"CheckpointCallback",
"TrainCallback",
"SchedulerCallback",
"StepMonitorCallback"
"MetricLoggerCallback"
]

View File

@ -163,7 +163,7 @@ class ProgressBarCallback(TrainCallback):
self.progress_bar.close()
class StepMonitorCallback(TrainCallback):
class MetricLoggerCallback(TrainCallback):
def __init__(
self,
log_dir:str,
@ -213,7 +213,7 @@ class StepMonitorCallback(TrainCallback):
for log in self.log_cache:
f.write(json.dumps(log) + '\n')
def on_step_end(self, context):
def on_batch_end(self, context):
if self.step_num % self.log_interval == 0:
log_data = self._get_log_data(context)
self._add_log(log_data)

View File

@ -5,7 +5,7 @@ from khaosz.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
CheckpointCallback,
StepMonitorCallback,
MetricLoggerCallback,
GradientClippingCallback,
SchedulerCallback
)
@ -31,7 +31,7 @@ class Trainer:
return [
ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(),
]