From 82d22c57423214f5ba498ddd32899ed8940d26c9 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 6 Mar 2026 10:51:22 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dcallback=20=E6=97=B6?= =?UTF-8?q?=E6=9C=BA=E4=B8=8D=E4=B8=80=E8=87=B4=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/__init__.py | 4 ++-- khaosz/trainer/train_callback.py | 4 ++-- khaosz/trainer/trainer.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index d856750..bbe99ee 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -8,7 +8,7 @@ from khaosz.trainer.train_callback import ( CheckpointCallback, TrainCallback, SchedulerCallback, - StepMonitorCallback + MetricLoggerCallback ) __all__ = [ @@ -25,5 +25,5 @@ __all__ = [ "CheckpointCallback", "TrainCallback", "SchedulerCallback", - "StepMonitorCallback" + "MetricLoggerCallback" ] \ No newline at end of file diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 6a161a8..bc1a504 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -163,7 +163,7 @@ class ProgressBarCallback(TrainCallback): self.progress_bar.close() -class StepMonitorCallback(TrainCallback): +class MetricLoggerCallback(TrainCallback): def __init__( self, log_dir:str, @@ -213,7 +213,7 @@ class StepMonitorCallback(TrainCallback): for log in self.log_cache: f.write(json.dumps(log) + '\n') - def on_step_end(self, context): + def on_batch_end(self, context): if self.step_num % self.log_interval == 0: log_data = self._get_log_data(context) self._add_log(log_data) diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index b997956..823511c 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -5,7 +5,7 @@ from khaosz.trainer.train_callback import ( TrainCallback, ProgressBarCallback, CheckpointCallback, - StepMonitorCallback, + MetricLoggerCallback, GradientClippingCallback, SchedulerCallback ) @@ -31,7 +31,7 @@ class Trainer: return [ ProgressBarCallback(train_config.n_epoch), CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), - StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), + MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), GradientClippingCallback(train_config.max_grad_norm), SchedulerCallback(), ]