fix: Fix parameter passing issue and update unit tests
parent b17cc6a6fb
commit 6089a12cef
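The substantive change: `TrainConfig.optimizer_fn` is now invoked with the model itself instead of a pre-extracted `model.parameters()` iterator, so each factory pulls out the parameters it needs. A minimal sketch of the new factory contract, assuming a plain PyTorch module (the toy `nn.Linear` model is illustrative only, not from this repo):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the real training model

# New contract: the factory receives the (possibly wrapped) model
# and extracts the parameters itself.
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters())

optimizer = optimizer_fn(model)  # what TrainContextBuilder now does internally
assert len(optimizer.param_groups) == 1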
@@ -47,7 +47,7 @@ class TrainContextBuilder:
         fn = self.config.parallel_wrapper
         self._context.model = fn(self._context.model)
 
-        self._context.optimizer = self.config.optimizer_fn(self._context.model.parameters())
+        self._context.optimizer = self.config.optimizer_fn(self._context.model)
         self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
 
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
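Passing the model rather than its parameters matters because of the line just above the change: the builder applies `self.config.parallel_wrapper` first, so the optimizer should be built from the wrapped module. A rough, runnable sketch of the new ordering; the identity wrapper and `LambdaLR` are stand-ins for the repo's real wrapper and `SchedulerFactory`:

import torch
import torch.nn as nn

raw_model = nn.Linear(4, 2)                    # illustrative stand-in
parallel_wrapper = lambda m: m                 # identity here; DDP/FSDP in real runs
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters())
scheduler_fn = lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)

# Builder ordering after this commit: wrap first, then build the optimizer
# from the wrapped model, then chain the scheduler off the optimizer.
model = parallel_wrapper(raw_model)
optimizer = optimizer_fn(model)
scheduler = scheduler_fn(optimizer)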
@@ -10,15 +10,15 @@ def test_callback_integration(base_test_env, random_dataset):
         total_steps=20
     )
 
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
 
     train_config = TrainConfig(
         model=base_test_env["model"],
         strategy='seq',
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
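The lambdas above keep the tests short; a named-function form is equivalent and avoids the lambda-assignment style that PEP 8 discourages. This variant reuses `SchedulerFactory` and the `schedule_config` already defined in the test, so it is a drop-in style alternative rather than new API:

import torch

def optimizer_fn(model):
    # Same behavior as `lambda model: torch.optim.AdamW(model.parameters())`.
    return torch.optim.AdamW(model.parameters())

def scheduler_fn(optim):
    # `SchedulerFactory` and `schedule_config` come from the test file above.
    return SchedulerFactory.load(optim, schedule_config)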
@@ -9,15 +9,16 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
 
     schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
 
     train_config = TrainConfig(
         strategy="seq",
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         model=base_test_env["model"],
         dataset=early_stopping_dataset,
-        optimizer=optimizer,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=2,
         batch_size=2,
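One subtlety in these factories: `scheduler_fn` closes over `schedule_config` by reference, so it sees whatever that name points to when the builder finally calls it, not when the lambda was written. A small self-contained illustration of that capture behavior (`lr_holder` is hypothetical, standing in for a mutable config object):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
lr_holder = {"lr": 1e-3}  # hypothetical mutable config, like schedule_config

# The lambda captures lr_holder by reference, not by value.
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=lr_holder["lr"])

lr_holder["lr"] = 5e-4                          # mutated before the builder runs
optimizer = optimizer_fn(model)
assert optimizer.param_groups[0]["lr"] == 5e-4  # the late value wins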
@@ -15,14 +15,15 @@ def test_different_batch_sizes(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=batch_size,
@@ -43,13 +44,14 @@ def test_gradient_accumulation(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         dataset=random_dataset,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
@@ -79,14 +81,15 @@ def test_memory_efficient_training(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    scheduler = SchedulerFactory.load(optimizer, schedule_config)
+    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
+    scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
+
     train_config = TrainConfig(
         strategy="seq",
         model=base_test_env["model"],
         dataset=random_dataset,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        optimizer_fn=optimizer_fn,
+        scheduler_fn=scheduler_fn,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=1,
         batch_size=config["batch_size"],