fix: 修复metric 保存时机的问题
This commit is contained in:
parent
e55b57d771
commit
e23a5ca426
|
|
@ -171,8 +171,7 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
log_interval:int=10,
|
log_interval:int=10,
|
||||||
metrics:List[str]=None
|
metrics:List[str]=None
|
||||||
):
|
):
|
||||||
self.step_num = 0
|
self.last_log_iter = 0
|
||||||
self.last_save_step = 0
|
|
||||||
self.save_interval = save_interval
|
self.save_interval = save_interval
|
||||||
self.log_interval = log_interval
|
self.log_interval = log_interval
|
||||||
self.metrics = metrics or ['loss', 'lr']
|
self.metrics = metrics or ['loss', 'lr']
|
||||||
|
|
@ -214,17 +213,16 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
f.write(json.dumps(log) + '\n')
|
f.write(json.dumps(log) + '\n')
|
||||||
|
|
||||||
def on_batch_end(self, context):
|
def on_batch_end(self, context):
|
||||||
if self.step_num % self.log_interval == 0:
|
if context.iteration % self.log_interval == 0:
|
||||||
log_data = self._get_log_data(context)
|
log_data = self._get_log_data(context)
|
||||||
self._add_log(log_data)
|
self._add_log(log_data)
|
||||||
|
|
||||||
if self.step_num - self.last_save_step >= self.save_interval:
|
if context.iteration - self.last_log_iter >= self.save_interval:
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
self.last_save_step = self.step_num
|
self.last_log_iter = context.iteration
|
||||||
|
|
||||||
self.step_num += 1
|
|
||||||
|
|
||||||
def on_train_end(self, context):
|
def on_train_end(self, context):
|
||||||
|
if context.iteration != self.last_log_iter:
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
def on_error(self, context):
|
def on_error(self, context):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue