fix: resolve parameter-passing issue and update unit tests

ViperEkura 2026-02-28 19:01:16 +08:00
parent b17cc6a6fb
commit 6089a12cef
4 changed files with 25 additions and 21 deletions

@@ -47,7 +47,7 @@ class TrainContextBuilder:
             fn = self.config.parallel_wrapper
             self._context.model = fn(self._context.model)
-        self._context.optimizer = self.config.optimizer_fn(self._context.model.parameters())
+        self._context.optimizer = self.config.optimizer_fn(self._context.model)
         self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
 
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
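The builder hunk above is the substance of the fix: TrainConfig now carries optimizer_fn and scheduler_fn factories instead of pre-built instances, and optimizer_fn is handed the (possibly parallel-wrapped) model itself rather than model.parameters(). A minimal sketch of the intended construction order, with a pass-through lambda standing in for the real parallel_wrapper (hypothetical stand-ins, not the repository's actual classes):

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
parallel_wrapper = lambda m: m  # stand-in; e.g. DistributedDataParallel in real runs
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters())
scheduler_fn = lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)

# Order mirrors TrainContextBuilder: wrap the model first, then build the
# optimizer from the wrapped model, then build the scheduler from that optimizer.
model = parallel_wrapper(model)
optimizer = optimizer_fn(model)      # the factory, not the builder, calls .parameters()
scheduler = scheduler_fn(optimizer)

Deferring construction this way means the optimizer always collects parameters from the wrapped model, which is presumably the parameter-passing issue the commit subject refers to.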

@@ -10,15 +10,15 @@ def test_callback_integration(base_test_env, random_dataset):
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
     train_config = TrainConfig(
         model=base_test_env["model"],
         strategy='seq',
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
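The same pattern repeats in every test below: the optimizer and scheduler are no longer instantiated eagerly but wrapped in lambdas that close over the local schedule_config, so the builder can invoke them after any model wrapping. A self-contained sketch of that closure behavior (SchedulerFactory's internals are not shown in the diff, so a plain LambdaLR cosine-with-warmup schedule stands in for SchedulerFactory.load here):

import math
import torch
import torch.nn as nn

warmup_steps, total_steps = 10, 20

optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
# Hypothetical stand-in for SchedulerFactory.load(optim, schedule_config):
scheduler_fn = lambda optim: torch.optim.lr_scheduler.LambdaLR(
    optim,
    lambda step: (step / warmup_steps) if step < warmup_steps
    else 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / (total_steps - warmup_steps))),
)

# The factories run only when the trainer decides to, e.g. after model wrapping.
model = nn.Linear(4, 2)
optimizer = optimizer_fn(model)
scheduler = scheduler_fn(optimizer)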

@@ -9,15 +9,16 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
     schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
     train_config = TrainConfig(
         strategy="seq",
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         model=base_test_env["model"],
         dataset=early_stopping_dataset,
-        optimizer=optimizer,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=2,
         batch_size=2,

@@ -15,14 +15,15 @@ def test_different_batch_sizes(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=batch_size,
@@ -43,13 +44,14 @@ def test_gradient_accumulation(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         dataset=random_dataset,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
@@ -79,14 +81,15 @@ def test_memory_efficient_training(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=config["batch_size"],