fix: 修复callback 时机不一致的问题

This commit is contained in:
ViperEkura 2026-03-06 10:51:22 +08:00
parent 96744ac2d2
commit 82d22c5742
3 changed files with 6 additions and 6 deletions

View File

@ -8,7 +8,7 @@ from khaosz.trainer.train_callback import (
CheckpointCallback, CheckpointCallback,
TrainCallback, TrainCallback,
SchedulerCallback, SchedulerCallback,
StepMonitorCallback MetricLoggerCallback
) )
__all__ = [ __all__ = [
@ -25,5 +25,5 @@ __all__ = [
"CheckpointCallback", "CheckpointCallback",
"TrainCallback", "TrainCallback",
"SchedulerCallback", "SchedulerCallback",
"StepMonitorCallback" "MetricLoggerCallback"
] ]

View File

@ -163,7 +163,7 @@ class ProgressBarCallback(TrainCallback):
self.progress_bar.close() self.progress_bar.close()
class StepMonitorCallback(TrainCallback): class MetricLoggerCallback(TrainCallback):
def __init__( def __init__(
self, self,
log_dir:str, log_dir:str,
@ -213,7 +213,7 @@ class StepMonitorCallback(TrainCallback):
for log in self.log_cache: for log in self.log_cache:
f.write(json.dumps(log) + '\n') f.write(json.dumps(log) + '\n')
def on_step_end(self, context): def on_batch_end(self, context):
if self.step_num % self.log_interval == 0: if self.step_num % self.log_interval == 0:
log_data = self._get_log_data(context) log_data = self._get_log_data(context)
self._add_log(log_data) self._add_log(log_data)

View File

@ -5,7 +5,7 @@ from khaosz.trainer.train_callback import (
TrainCallback, TrainCallback,
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
StepMonitorCallback, MetricLoggerCallback,
GradientClippingCallback, GradientClippingCallback,
SchedulerCallback SchedulerCallback
) )
@ -31,7 +31,7 @@ class Trainer:
return [ return [
ProgressBarCallback(train_config.n_epoch), ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
GradientClippingCallback(train_config.max_grad_norm), GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(), SchedulerCallback(),
] ]