fix: 修复metric 保存时机的问题
This commit is contained in:
parent
e55b57d771
commit
e23a5ca426
|
|
@ -171,8 +171,7 @@ class MetricLoggerCallback(TrainCallback):
|
|||
log_interval:int=10,
|
||||
metrics:List[str]=None
|
||||
):
|
||||
self.step_num = 0
|
||||
self.last_save_step = 0
|
||||
self.last_log_iter = 0
|
||||
self.save_interval = save_interval
|
||||
self.log_interval = log_interval
|
||||
self.metrics = metrics or ['loss', 'lr']
|
||||
|
|
@ -214,18 +213,17 @@ class MetricLoggerCallback(TrainCallback):
|
|||
f.write(json.dumps(log) + '\n')
|
||||
|
||||
def on_batch_end(self, context):
|
||||
if self.step_num % self.log_interval == 0:
|
||||
if context.iteration % self.log_interval == 0:
|
||||
log_data = self._get_log_data(context)
|
||||
self._add_log(log_data)
|
||||
|
||||
if self.step_num - self.last_save_step >= self.save_interval:
|
||||
if context.iteration - self.last_log_iter >= self.save_interval:
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
self.last_save_step = self.step_num
|
||||
|
||||
self.step_num += 1
|
||||
self.last_log_iter = context.iteration
|
||||
|
||||
def on_train_end(self, context):
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
if context.iteration != self.last_log_iter:
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
||||
def on_error(self, context):
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
|
|
|||
Loading…
Reference in New Issue