fix(config): 修正 SGDRScheduleConfig 类名拼写错误

This commit is contained in:
ViperEkura 2025-10-20 18:21:46 +08:00
parent e051005334
commit bc5ef72001
3 changed files with 4 additions and 4 deletions

View File

@ -1,6 +1,6 @@
from khaosz.config.model_config import TransformerConfig from khaosz.config.model_config import TransformerConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SgdrScheduleConfig from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
@ -14,5 +14,5 @@ __all__ = [
"ScheduleConfig", "ScheduleConfig",
"CosineScheduleConfig", "CosineScheduleConfig",
"SgdrScheduleConfig", "SGDRScheduleConfig",
] ]

View File

@ -59,7 +59,7 @@ class CosineScheduleConfig(ScheduleConfig):
@dataclass @dataclass
class SgdrScheduleConfig(ScheduleConfig): class SGDRScheduleConfig(ScheduleConfig):
cycle_length: int = field( cycle_length: int = field(
default=1000, default=1000,
metadata={"help": "Length of the first cycle in steps."} metadata={"help": "Length of the first cycle in steps."}

View File

@ -55,7 +55,7 @@ def test_schedule_factory_random_configs():
total_steps=np.random.randint(1000, 5000), total_steps=np.random.randint(1000, 5000),
min_rate=np.random.uniform(0.01, 0.1) min_rate=np.random.uniform(0.01, 0.1)
), ),
SgdrScheduleConfig( SGDRScheduleConfig(
warmup_steps=np.random.randint(50, 200), warmup_steps=np.random.randint(50, 200),
cycle_length=np.random.randint(500, 2000), cycle_length=np.random.randint(500, 2000),
t_mult=np.random.randint(1, 3), t_mult=np.random.randint(1, 3),