fix(trainer): 修复参数传递问题和检查点保存问题
This commit is contained in:
parent
c98b175cd5
commit
c934210066
|
|
@ -93,4 +93,13 @@ class TrainConfig:
|
||||||
kwargs: dict = field(
|
kwargs: dict = field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
metadata={"help": "Other arguments."}
|
metadata={"help": "Other arguments."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.validate()
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
required_fields = ["model", "strategy", "dataset", "optimizer", "scheduler"]
|
||||||
|
for field_name in required_fields:
|
||||||
|
if getattr(self, field_name) is None:
|
||||||
|
raise ValueError(f"{field_name} is required.")
|
||||||
|
|
|
||||||
|
|
@ -94,18 +94,17 @@ class CheckpointCallback(TrainCallback):
|
||||||
def __init__(self, interval: int, save_dir: str):
|
def __init__(self, interval: int, save_dir: str):
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.checkpoint = None
|
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, context: 'TrainContext'):
|
def _save_checkpoint(self, context: 'TrainContext'):
|
||||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
|
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
|
||||||
self.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
context.optimizer.state_dict(),
|
context.optimizer.state_dict(),
|
||||||
context.scheduler.state_dict(),
|
context.scheduler.state_dict(),
|
||||||
context.epoch,
|
context.epoch,
|
||||||
context.iteration
|
context.iteration
|
||||||
)
|
)
|
||||||
self.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
def on_batch_end(self, context: 'TrainContext'):
|
def on_batch_end(self, context: 'TrainContext'):
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,18 @@ def test_different_batch_sizes(base_test_env, random_dataset):
|
||||||
batch_sizes = [1, 2, 4, 8]
|
batch_sizes = [1, 2, 4, 8]
|
||||||
|
|
||||||
for batch_size in batch_sizes:
|
for batch_size in batch_sizes:
|
||||||
|
schedule_config = CosineScheduleConfig(
|
||||||
|
warmup_steps=10,
|
||||||
|
total_steps=20
|
||||||
|
)
|
||||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||||
|
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
|
strategy="seq",
|
||||||
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
checkpoint_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|
@ -67,10 +75,18 @@ def test_memory_efficient_training(base_test_env, random_dataset):
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
|
schedule_config = CosineScheduleConfig(
|
||||||
|
warmup_steps=10,
|
||||||
|
total_steps=20
|
||||||
|
)
|
||||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||||
|
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
|
strategy="seq",
|
||||||
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
checkpoint_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=config["batch_size"],
|
batch_size=config["batch_size"],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue