fix(config): 修正 SGDRScheduleConfig 类名拼写错误
This commit is contained in:
parent
e051005334
commit
bc5ef72001
|
|
@ -1,6 +1,6 @@
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import TransformerConfig
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
||||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SgdrScheduleConfig
|
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,5 +14,5 @@ __all__ = [
|
||||||
|
|
||||||
"ScheduleConfig",
|
"ScheduleConfig",
|
||||||
"CosineScheduleConfig",
|
"CosineScheduleConfig",
|
||||||
"SgdrScheduleConfig",
|
"SGDRScheduleConfig",
|
||||||
]
|
]
|
||||||
|
|
@ -59,7 +59,7 @@ class CosineScheduleConfig(ScheduleConfig):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SgdrScheduleConfig(ScheduleConfig):
|
class SGDRScheduleConfig(ScheduleConfig):
|
||||||
cycle_length: int = field(
|
cycle_length: int = field(
|
||||||
default=1000,
|
default=1000,
|
||||||
metadata={"help": "Length of the first cycle in steps."}
|
metadata={"help": "Length of the first cycle in steps."}
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ def test_schedule_factory_random_configs():
|
||||||
total_steps=np.random.randint(1000, 5000),
|
total_steps=np.random.randint(1000, 5000),
|
||||||
min_rate=np.random.uniform(0.01, 0.1)
|
min_rate=np.random.uniform(0.01, 0.1)
|
||||||
),
|
),
|
||||||
SgdrScheduleConfig(
|
SGDRScheduleConfig(
|
||||||
warmup_steps=np.random.randint(50, 200),
|
warmup_steps=np.random.randint(50, 200),
|
||||||
cycle_length=np.random.randint(500, 2000),
|
cycle_length=np.random.randint(500, 2000),
|
||||||
t_mult=np.random.randint(1, 3),
|
t_mult=np.random.randint(1, 3),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue