From 0db046f8d99399c03ce17ef4c812409936c9599e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 20 Oct 2025 13:30:26 +0800 Subject: [PATCH] =?UTF-8?q?feat(khaosz/trainer):=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E6=A2=AF=E5=BA=A6=E8=A3=81=E5=89=AA=E5=9B=9E=E8=B0=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/train_callback.py | 5 ++++- khaosz/trainer/trainer.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 1673221..4805e37 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -60,9 +60,12 @@ class GradientClippingCallback(TrainCallback): """ Gradient clipping callback for trainer. """ + def __init__(self, max_grad_norm: float): + self.max_grad_norm = max_grad_norm + def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): _ = context - clip_grad_norm_(trainer.parameter.model.parameters(), trainer.train_config.max_grad_norm) + clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm) class SchedulerCallback(TrainCallback): diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 1565584..be28e8c 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -35,7 +35,7 @@ class Trainer: return [ ProgressBarCallback(), CheckpointCallback(self.train_config.checkpoint_interval), - GradientClippingCallback(), + GradientClippingCallback(self.train_config.max_grad_norm), SchedulerCallback(self.schedule_config), ]