From e23a5ca4267423e83f3bd169cb0ce4778179d711 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 16 Mar 2026 20:07:36 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dmetric=20=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E6=97=B6=E6=9C=BA=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/train_callback.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index bc1a504..5fa9719 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -171,8 +171,7 @@ class MetricLoggerCallback(TrainCallback): log_interval:int=10, metrics:List[str]=None ): - self.step_num = 0 - self.last_save_step = 0 + self.last_log_iter = 0 self.save_interval = save_interval self.log_interval = log_interval self.metrics = metrics or ['loss', 'lr'] @@ -214,18 +213,17 @@ class MetricLoggerCallback(TrainCallback): f.write(json.dumps(log) + '\n') def on_batch_end(self, context): - if self.step_num % self.log_interval == 0: + if context.iteration % self.log_interval == 0: log_data = self._get_log_data(context) self._add_log(log_data) - if self.step_num - self.last_save_step >= self.save_interval: + if context.iteration - self.last_log_iter >= self.save_interval: self._save_log(context.epoch, context.iteration) - self.last_save_step = self.step_num - - self.step_num += 1 + self.last_log_iter = context.iteration def on_train_end(self, context): - self._save_log(context.epoch, context.iteration) + if context.iteration != self.last_log_iter: + self._save_log(context.epoch, context.iteration) def on_error(self, context): self._save_log(context.epoch, context.iteration)