feat(khaosz/trainer): 更新梯度裁剪回调

This commit is contained in:
ViperEkura 2025-10-20 13:30:26 +08:00
parent 05b012820b
commit 0db046f8d9
2 changed files with 5 additions and 2 deletions

View File

@ -60,9 +60,12 @@ class GradientClippingCallback(TrainCallback):
"""
Gradient clipping callback for trainer.
"""
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
_ = context
clip_grad_norm_(trainer.parameter.model.parameters(), trainer.train_config.max_grad_norm)
clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm)
class SchedulerCallback(TrainCallback):

View File

@ -35,7 +35,7 @@ class Trainer:
return [
ProgressBarCallback(),
CheckpointCallback(self.train_config.checkpoint_interval),
GradientClippingCallback(),
GradientClippingCallback(self.train_config.max_grad_norm),
SchedulerCallback(self.schedule_config),
]