fix: 修复callback 时机不一致的问题
This commit is contained in:
parent
96744ac2d2
commit
82d22c5742
|
|
@ -8,7 +8,7 @@ from khaosz.trainer.train_callback import (
|
|||
CheckpointCallback,
|
||||
TrainCallback,
|
||||
SchedulerCallback,
|
||||
StepMonitorCallback
|
||||
MetricLoggerCallback
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -25,5 +25,5 @@ __all__ = [
|
|||
"CheckpointCallback",
|
||||
"TrainCallback",
|
||||
"SchedulerCallback",
|
||||
"StepMonitorCallback"
|
||||
"MetricLoggerCallback"
|
||||
]
|
||||
|
|
@ -163,7 +163,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
self.progress_bar.close()
|
||||
|
||||
|
||||
class StepMonitorCallback(TrainCallback):
|
||||
class MetricLoggerCallback(TrainCallback):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir:str,
|
||||
|
|
@ -213,7 +213,7 @@ class StepMonitorCallback(TrainCallback):
|
|||
for log in self.log_cache:
|
||||
f.write(json.dumps(log) + '\n')
|
||||
|
||||
def on_step_end(self, context):
|
||||
def on_batch_end(self, context):
|
||||
if self.step_num % self.log_interval == 0:
|
||||
log_data = self._get_log_data(context)
|
||||
self._add_log(log_data)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from khaosz.trainer.train_callback import (
|
|||
TrainCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
StepMonitorCallback,
|
||||
MetricLoggerCallback,
|
||||
GradientClippingCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
|
|
@ -31,7 +31,7 @@ class Trainer:
|
|||
return [
|
||||
ProgressBarCallback(train_config.n_epoch),
|
||||
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||
StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||
GradientClippingCallback(train_config.max_grad_norm),
|
||||
SchedulerCallback(),
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue