fix: 修复 metric 保存时机的问题

This commit is contained in:
ViperEkura 2026-03-16 20:07:36 +08:00
parent e55b57d771
commit e23a5ca426
1 changed file with 6 additions and 8 deletions

View File

@@ -171,8 +171,7 @@ class MetricLoggerCallback(TrainCallback):
log_interval:int=10,
metrics:List[str]=None
):
self.step_num = 0
self.last_save_step = 0
self.last_log_iter = 0
self.save_interval = save_interval
self.log_interval = log_interval
self.metrics = metrics or ['loss', 'lr']
@@ -214,18 +213,17 @@ class MetricLoggerCallback(TrainCallback):
f.write(json.dumps(log) + '\n')
def on_batch_end(self, context):
if self.step_num % self.log_interval == 0:
if context.iteration % self.log_interval == 0:
log_data = self._get_log_data(context)
self._add_log(log_data)
if self.step_num - self.last_save_step >= self.save_interval:
if context.iteration - self.last_log_iter >= self.save_interval:
self._save_log(context.epoch, context.iteration)
self.last_save_step = self.step_num
self.step_num += 1
self.last_log_iter = context.iteration
def on_train_end(self, context):
self._save_log(context.epoch, context.iteration)
if context.iteration != self.last_log_iter:
self._save_log(context.epoch, context.iteration)
def on_error(self, context):
self._save_log(context.epoch, context.iteration)