feat(khaosz/trainer): 改进调度器配置验证和加载逻辑

This commit is contained in:
ViperEkura 2025-09-29 17:17:45 +08:00
parent bdda1cc35a
commit 6d5176a11c
2 changed files with 5 additions and 19 deletions

View File

@ -154,9 +154,9 @@ class SchedulerCallback(TrainerCallback):
def on_train_begin(self, trainer: 'Trainer', **kwargs):
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
self.current_iter = len(checkpoint.loss_list)
self.schedule_config.validate()
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
**self.schedule_config.get_kwargs()
self.schedule_config
)
self.scheduler = LambdaLR(

View File

@ -386,24 +386,10 @@ class SchedulerFactory:
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))
return cosine_schedule
@staticmethod
def create_schedule(config: ScheduleConfig) -> Callable[[int], float]:
"""
Create schedule from configuration.
Args:
config: Schedule configuration instance
Returns:
Schedule function
"""
config.validate()
kwargs = config.get_kwargs()
return SchedulerFactory.load_schedule_fn(**kwargs)
@staticmethod
def load_schedule_fn(**kwargs) -> Callable[[int], float]:
def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]:
kwargs = scedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")
if schedule_type == "cosine":