From 6089a12cefaee260c92e07b7c08f50e755a8b22e Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 28 Feb 2026 19:01:16 +0800
Subject: [PATCH] fix: Fix parameter passing issue and update unit tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
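Notes:

TrainContextBuilder invoked optimizer_fn with
self._context.model.parameters(), but the factory callables supplied by
the tests are written to take the model itself. The builder now passes
the (possibly parallel-wrapped) model and leaves it to the callable to
pull out the parameters; the tests are updated to supply
optimizer_fn/scheduler_fn factories instead of pre-built
optimizer/scheduler objects. A minimal sketch of the callables
TrainConfig expects after this change; the import path for
SchedulerFactory and CosineScheduleConfig is an assumption, since the
tests' import lines are outside the hunks below:

    import torch
    # Assumed import path for the project helpers used by the tests;
    # their actual module is not shown in this patch.
    from khaosz.trainer import SchedulerFactory, CosineScheduleConfig

    schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)

    # Called with the model itself (after any parallel_wrapper has been
    # applied), not with a parameter iterator.
    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())

    # Called with the optimizer that optimizer_fn returned.
    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)

Because the builder applies parallel_wrapper before calling
optimizer_fn, the optimizer is always constructed from the wrapped
model's parameters.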
 khaosz/trainer/train_context.py      |  2 +-
 tests/trainer/test_callbacks.py      |  8 ++++----
 tests/trainer/test_early_stopping.py |  9 +++++----
 tests/trainer/test_trainer.py        | 27 +++++++++++++++------------
 4 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index e4e4dcb..e1095f9 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -47,7 +47,7 @@ class TrainContextBuilder:
             fn = self.config.parallel_wrapper
             self._context.model = fn(self._context.model)
 
-        self._context.optimizer = self.config.optimizer_fn(self._context.model.parameters())
+        self._context.optimizer = self.config.optimizer_fn(self._context.model)
         self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
 
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py
index d4c1581..7855d7d 100644
--- a/tests/trainer/test_callbacks.py
+++ b/tests/trainer/test_callbacks.py
@@ -10,15 +10,15 @@ def test_callback_integration(base_test_env, random_dataset):
         total_steps=20
     )
 
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
 
     train_config = TrainConfig(
         model=base_test_env["model"],
         strategy='seq',
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py
index 2ead17d..7070bc5 100644
--- a/tests/trainer/test_early_stopping.py
+++ b/tests/trainer/test_early_stopping.py
@@ -9,15 +9,16 @@
 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
     schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
 
     train_config = TrainConfig(
         strategy="seq",
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         model=base_test_env["model"],
         dataset=early_stopping_dataset,
-        optimizer=optimizer,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=2,
         batch_size=2,
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 7470bdf..27261ce 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -15,14 +15,15 @@ def test_different_batch_sizes(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=batch_size,
@@ -43,13 +44,14 @@ def test_gradient_accumulation(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         dataset=random_dataset,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
@@ -79,14 +81,15 @@ def test_memory_efficient_training(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=config["batch_size"],