From c934210066d032a966ee516cd2db893f53b016b4 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 8 Dec 2025 13:28:11 +0800
Subject: [PATCH] =?UTF-8?q?fix(trainer):=20=E4=BF=AE=E5=A4=8D=E5=8F=82?=
 =?UTF-8?q?=E6=95=B0=E4=BC=A0=E9=80=92=E9=97=AE=E9=A2=98=E5=92=8C=E6=A3=80?=
 =?UTF-8?q?=E6=9F=A5=E7=82=B9=E4=BF=9D=E5=AD=98=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/config/train_config.py    | 11 ++++++++++-
 khaosz/trainer/train_callback.py |  5 ++---
 tests/test_train_config.py       | 16 ++++++++++++++++
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py
index f089ed4..4d899c6 100644
--- a/khaosz/config/train_config.py
+++ b/khaosz/config/train_config.py
@@ -93,4 +93,13 @@ class TrainConfig:
     kwargs: dict = field(
         default_factory=dict,
         metadata={"help": "Other arguments."}
-    )
\ No newline at end of file
+    )
+
+    def __post_init__(self):
+        self.validate()
+
+    def validate(self):
+        required_fields = ["model", "strategy", "dataset", "optimizer", "scheduler"]
+        for field_name in required_fields:
+            if getattr(self, field_name) is None:
+                raise ValueError(f"{field_name} is required.")
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index 90ef9f6..dd31629 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -94,18 +94,17 @@ class CheckpointCallback(TrainCallback):
     def __init__(self, interval: int, save_dir: str):
         self.interval = interval
         self.save_dir = save_dir
-        self.checkpoint = None
         self.last_ckpt_iter = 0
 
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
-        self.checkpoint = Checkpoint(
+        context.checkpoint = Checkpoint(
             context.optimizer.state_dict(),
             context.scheduler.state_dict(),
             context.epoch,
             context.iteration
         )
-        self.checkpoint.save(save_path)
+        context.checkpoint.save(save_path)
         self.last_ckpt_iter = context.iteration
 
     def on_batch_end(self, context: 'TrainContext'):
diff --git a/tests/test_train_config.py b/tests/test_train_config.py
index 38d1778..7470bdf 100644
--- a/tests/test_train_config.py
+++ b/tests/test_train_config.py
@@ -11,10 +11,18 @@ def test_different_batch_sizes(base_test_env, random_dataset):
     batch_sizes = [1, 2, 4, 8]
 
     for batch_size in batch_sizes:
+        schedule_config = CosineScheduleConfig(
+            warmup_steps=10,
+            total_steps=20
+        )
         optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
+        scheduler = SchedulerFactory.load(optimizer, schedule_config)
         train_config = TrainConfig(
+            strategy="seq",
+            model=base_test_env["model"],
             dataset=random_dataset,
             optimizer=optimizer,
+            scheduler=scheduler,
             checkpoint_dir=base_test_env["test_dir"],
             n_epoch=1,
             batch_size=batch_size,
@@ -67,10 +75,18 @@ def test_memory_efficient_training(base_test_env, random_dataset):
     ]
 
     for config in small_batch_configs:
+        schedule_config = CosineScheduleConfig(
+            warmup_steps=10,
+            total_steps=20
+        )
         optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
+        scheduler = SchedulerFactory.load(optimizer, schedule_config)
         train_config = TrainConfig(
+            strategy="seq",
+            model=base_test_env["model"],
             dataset=random_dataset,
             optimizer=optimizer,
+            scheduler=scheduler,
             checkpoint_dir=base_test_env["test_dir"],
             n_epoch=1,
             batch_size=config["batch_size"],