From 6d5176a11cb5b71c9269cce16368e2ead37ecde6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 29 Sep 2025 17:17:45 +0800 Subject: [PATCH] =?UTF-8?q?feat(khaosz/trainer):=20=E6=94=B9=E8=BF=9B?= =?UTF-8?q?=E8=B0=83=E5=BA=A6=E5=99=A8=E9=85=8D=E7=BD=AE=E9=AA=8C=E8=AF=81?= =?UTF-8?q?=E5=92=8C=E5=8A=A0=E8=BD=BD=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/callback.py | 4 ++-- khaosz/trainer/strategy.py | 20 +++----------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py index bb0a3ea..6505833 100644 --- a/khaosz/trainer/callback.py +++ b/khaosz/trainer/callback.py @@ -154,9 +154,9 @@ class SchedulerCallback(TrainerCallback): def on_train_begin(self, trainer: 'Trainer', **kwargs): checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) self.current_iter = len(checkpoint.loss_list) - + self.schedule_config.validate() lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( - **self.schedule_config.get_kwargs() + self.schedule_config ) self.scheduler = LambdaLR( diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index eadaf1f..93a2fea 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -386,24 +386,10 @@ class SchedulerFactory: return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress))) return cosine_schedule - + @staticmethod - def create_schedule(config: ScheduleConfig) -> Callable[[int], float]: - """ - Create schedule from configuration. - - Args: - config: Schedule configuration instance - - Returns: - Schedule function - """ - config.validate() - kwargs = config.get_kwargs() - return SchedulerFactory.load_schedule_fn(**kwargs) - - @staticmethod - def load_schedule_fn(**kwargs) -> Callable[[int], float]: + def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]: + kwargs = scedule_config.get_kwargs() schedule_type = kwargs.pop("schedule_type") if schedule_type == "cosine":