From bc5ef72001f232f85dbe418bf8bd876431fa8a4a Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 20 Oct 2025 18:21:46 +0800 Subject: [PATCH] =?UTF-8?q?fix(config):=20=E4=BF=AE=E6=AD=A3=20SGDRSchedul?= =?UTF-8?q?eConfig=20=E7=B1=BB=E5=90=8D=E6=8B=BC=E5=86=99=E9=94=99?= =?UTF-8?q?=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/config/__init__.py | 4 ++-- khaosz/config/schedule_config.py | 2 +- tests/test_train_strategy.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/khaosz/config/__init__.py b/khaosz/config/__init__.py index eae4607..ba39ad6 100644 --- a/khaosz/config/__init__.py +++ b/khaosz/config/__init__.py @@ -1,6 +1,6 @@ from khaosz.config.model_config import TransformerConfig from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader -from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SgdrScheduleConfig +from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig from khaosz.config.train_config import TrainConfig @@ -14,5 +14,5 @@ __all__ = [ "ScheduleConfig", "CosineScheduleConfig", - "SgdrScheduleConfig", + "SGDRScheduleConfig", ] \ No newline at end of file diff --git a/khaosz/config/schedule_config.py b/khaosz/config/schedule_config.py index 82b99c6..3724ab8 100644 --- a/khaosz/config/schedule_config.py +++ b/khaosz/config/schedule_config.py @@ -59,7 +59,7 @@ class CosineScheduleConfig(ScheduleConfig): @dataclass -class SgdrScheduleConfig(ScheduleConfig): +class SGDRScheduleConfig(ScheduleConfig): cycle_length: int = field( default=1000, metadata={"help": "Length of the first cycle in steps."} diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py index 4f4fadf..da8a8c0 100644 --- a/tests/test_train_strategy.py +++ b/tests/test_train_strategy.py @@ -55,7 +55,7 @@ def test_schedule_factory_random_configs(): total_steps=np.random.randint(1000, 5000), min_rate=np.random.uniform(0.01, 0.1) ), - SgdrScheduleConfig( + SGDRScheduleConfig( warmup_steps=np.random.randint(50, 200), cycle_length=np.random.randint(500, 2000), t_mult=np.random.randint(1, 3),