diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index d856750..bbe99ee 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -8,7 +8,7 @@ from khaosz.trainer.train_callback import ( CheckpointCallback, TrainCallback, SchedulerCallback, - StepMonitorCallback + MetricLoggerCallback ) __all__ = [ @@ -25,5 +25,5 @@ __all__ = [ "CheckpointCallback", "TrainCallback", "SchedulerCallback", - "StepMonitorCallback" + "MetricLoggerCallback" ] \ No newline at end of file diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 6a161a8..bc1a504 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -163,7 +163,7 @@ class ProgressBarCallback(TrainCallback): self.progress_bar.close() -class StepMonitorCallback(TrainCallback): +class MetricLoggerCallback(TrainCallback): def __init__( self, log_dir:str, @@ -213,7 +213,7 @@ class StepMonitorCallback(TrainCallback): for log in self.log_cache: f.write(json.dumps(log) + '\n') - def on_step_end(self, context): + def on_batch_end(self, context): if self.step_num % self.log_interval == 0: log_data = self._get_log_data(context) self._add_log(log_data) diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index b997956..823511c 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -5,7 +5,7 @@ from khaosz.trainer.train_callback import ( TrainCallback, ProgressBarCallback, CheckpointCallback, - StepMonitorCallback, + MetricLoggerCallback, GradientClippingCallback, SchedulerCallback ) @@ -31,7 +31,7 @@ class Trainer: return [ ProgressBarCallback(train_config.n_epoch), CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), - StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), + MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), GradientClippingCallback(train_config.max_grad_norm), SchedulerCallback(), ]