fix: 修复callback 时机不一致的问题
This commit is contained in:
parent
96744ac2d2
commit
82d22c5742
|
|
@ -8,7 +8,7 @@ from khaosz.trainer.train_callback import (
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
SchedulerCallback,
|
SchedulerCallback,
|
||||||
StepMonitorCallback
|
MetricLoggerCallback
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -25,5 +25,5 @@ __all__ = [
|
||||||
"CheckpointCallback",
|
"CheckpointCallback",
|
||||||
"TrainCallback",
|
"TrainCallback",
|
||||||
"SchedulerCallback",
|
"SchedulerCallback",
|
||||||
"StepMonitorCallback"
|
"MetricLoggerCallback"
|
||||||
]
|
]
|
||||||
|
|
@ -163,7 +163,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
self.progress_bar.close()
|
self.progress_bar.close()
|
||||||
|
|
||||||
|
|
||||||
class StepMonitorCallback(TrainCallback):
|
class MetricLoggerCallback(TrainCallback):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
log_dir:str,
|
log_dir:str,
|
||||||
|
|
@ -213,7 +213,7 @@ class StepMonitorCallback(TrainCallback):
|
||||||
for log in self.log_cache:
|
for log in self.log_cache:
|
||||||
f.write(json.dumps(log) + '\n')
|
f.write(json.dumps(log) + '\n')
|
||||||
|
|
||||||
def on_step_end(self, context):
|
def on_batch_end(self, context):
|
||||||
if self.step_num % self.log_interval == 0:
|
if self.step_num % self.log_interval == 0:
|
||||||
log_data = self._get_log_data(context)
|
log_data = self._get_log_data(context)
|
||||||
self._add_log(log_data)
|
self._add_log(log_data)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from khaosz.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
StepMonitorCallback,
|
MetricLoggerCallback,
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
SchedulerCallback
|
SchedulerCallback
|
||||||
)
|
)
|
||||||
|
|
@ -31,7 +31,7 @@ class Trainer:
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(train_config.n_epoch),
|
ProgressBarCallback(train_config.n_epoch),
|
||||||
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||||
StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||||
GradientClippingCallback(train_config.max_grad_norm),
|
GradientClippingCallback(train_config.max_grad_norm),
|
||||||
SchedulerCallback(),
|
SchedulerCallback(),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue