fix(trainer): fix parameter passing and checkpoint saving issues

ViperEkura 2025-12-08 13:28:11 +08:00
parent c98b175cd5
commit c934210066
3 changed files with 28 additions and 4 deletions


@@ -93,4 +93,13 @@ class TrainConfig:
     kwargs: dict = field(
         default_factory=dict,
         metadata={"help": "Other arguments."}
     )
+
+    def __post_init__(self):
+        self.validate()
+
+    def validate(self):
+        required_fields = ["model", "strategy", "dataset", "optimizer", "scheduler"]
+        for field_name in required_fields:
+            if getattr(self, field_name) is None:
+                raise ValueError(f"{field_name} is required.")
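
With validation wired into __post_init__, a misconfigured TrainConfig now fails at construction time instead of partway through training. A minimal sketch of the new behavior, assuming TrainConfig is importable and its other fields all have defaults (the None placeholders are illustrative):

    try:
        TrainConfig(
            model=None,       # any missing required field trips validation
            strategy=None,
            dataset=None,
            optimizer=None,
            scheduler=None,
        )
    except ValueError as e:
        print(e)  # "model is required." -- the first missing field is reported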


@@ -94,18 +94,17 @@ class CheckpointCallback(TrainCallback):
     def __init__(self, interval: int, save_dir: str):
         self.interval = interval
         self.save_dir = save_dir
-        self.checkpoint = None
         self.last_ckpt_iter = 0
 
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
-        self.checkpoint = Checkpoint(
+        context.checkpoint = Checkpoint(
             context.optimizer.state_dict(),
             context.scheduler.state_dict(),
             context.epoch,
             context.iteration
         )
-        self.checkpoint.save(save_path)
+        context.checkpoint.save(save_path)
         self.last_ckpt_iter = context.iteration
 
     def on_batch_end(self, context: 'TrainContext'):
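
The checkpoint fix moves the saved Checkpoint from the callback's private state onto the shared TrainContext, so the latest checkpoint is visible to the rest of the training loop. An illustrative sketch, assuming TrainContext carries a checkpoint attribute as this diff sets it; the companion callback below is hypothetical:

    class CheckpointLogger(TrainCallback):
        # Hypothetical callback: only context.checkpoint comes from this commit.
        def on_batch_end(self, context: 'TrainContext'):
            if getattr(context, "checkpoint", None) is not None:
                # Previously this state was trapped inside CheckpointCallback.
                print(f"latest checkpoint covers iteration {context.iteration}")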


@@ -11,10 +11,18 @@ def test_different_batch_sizes(base_test_env, random_dataset):
     batch_sizes = [1, 2, 4, 8]
 
     for batch_size in batch_sizes:
+        schedule_config = CosineScheduleConfig(
+            warmup_steps=10,
+            total_steps=20
+        )
         optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
+        scheduler = SchedulerFactory.load(optimizer, schedule_config)
         train_config = TrainConfig(
+            strategy="seq",
+            model=base_test_env["model"],
             dataset=random_dataset,
             optimizer=optimizer,
+            scheduler=scheduler,
             checkpoint_dir=base_test_env["test_dir"],
             n_epoch=1,
             batch_size=batch_size,
@@ -67,10 +75,18 @@ def test_memory_efficient_training(base_test_env, random_dataset):
     ]
 
     for config in small_batch_configs:
+        schedule_config = CosineScheduleConfig(
+            warmup_steps=10,
+            total_steps=20
+        )
         optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
+        scheduler = SchedulerFactory.load(optimizer, schedule_config)
         train_config = TrainConfig(
+            strategy="seq",
+            model=base_test_env["model"],
             dataset=random_dataset,
             optimizer=optimizer,
+            scheduler=scheduler,
             checkpoint_dir=base_test_env["test_dir"],
             n_epoch=1,
             batch_size=config["batch_size"],
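
Both tests now build a scheduler before constructing TrainConfig, since validate() rejects a missing scheduler (or model/strategy). The shared setup pattern, assuming CosineScheduleConfig and SchedulerFactory.load behave as the tests use them:

    schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
    optimizer = torch.optim.AdamW(model.parameters())
    scheduler = SchedulerFactory.load(optimizer, schedule_config)  # builds the LR scheduler from the config
    # TrainConfig(..., optimizer=optimizer, scheduler=scheduler, ...) now passes validation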