feat(khaosz/trainer): 改进调度器配置验证和加载逻辑
This commit is contained in:
parent
bdda1cc35a
commit
6d5176a11c
|
|
@ -154,9 +154,9 @@ class SchedulerCallback(TrainerCallback):
|
|||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
self.current_iter = len(checkpoint.loss_list)
|
||||
|
||||
self.schedule_config.validate()
|
||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||
**self.schedule_config.get_kwargs()
|
||||
self.schedule_config
|
||||
)
|
||||
|
||||
self.scheduler = LambdaLR(
|
||||
|
|
|
|||
|
|
@ -388,22 +388,8 @@ class SchedulerFactory:
|
|||
return cosine_schedule
|
||||
|
||||
@staticmethod
|
||||
def create_schedule(config: ScheduleConfig) -> Callable[[int], float]:
|
||||
"""
|
||||
Create schedule from configuration.
|
||||
|
||||
Args:
|
||||
config: Schedule configuration instance
|
||||
|
||||
Returns:
|
||||
Schedule function
|
||||
"""
|
||||
config.validate()
|
||||
kwargs = config.get_kwargs()
|
||||
return SchedulerFactory.load_schedule_fn(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load_schedule_fn(**kwargs) -> Callable[[int], float]:
|
||||
def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]:
|
||||
kwargs = scedule_config.get_kwargs()
|
||||
schedule_type = kwargs.pop("schedule_type")
|
||||
|
||||
if schedule_type == "cosine":
|
||||
|
|
|
|||
Loading…
Reference in New Issue