test(sampler): 删除冗余的训练恢复测试用例
This commit is contained in:
parent
85aeec9e55
commit
6e1a497c04
|
|
@ -1,5 +1,3 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
from khaosz.core import *
|
from khaosz.core import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
from khaosz.trainer.data_util import *
|
||||||
|
|
@ -57,55 +55,7 @@ def test_sampler_state_persistence(random_dataset):
|
||||||
|
|
||||||
assert indices2 == indices3
|
assert indices2 == indices3
|
||||||
|
|
||||||
def test_training_resume_with_sampler(base_test_env, random_dataset):
|
def test_sampler_across_epochs(random_dataset):
|
||||||
"""Test that training can resume correctly with sampler state"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
dataset = random_dataset
|
|
||||||
|
|
||||||
# Initial training config
|
|
||||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
|
||||||
train_config = TrainConfig(
|
|
||||||
dataset=dataset,
|
|
||||||
optimizer=optimizer,
|
|
||||||
checkpoint_dir=test_dir,
|
|
||||||
n_epoch=1,
|
|
||||||
batch_size=2,
|
|
||||||
checkpoint_interval=5,
|
|
||||||
accumulation_steps=1,
|
|
||||||
max_grad_norm=1.0,
|
|
||||||
random_seed=42
|
|
||||||
)
|
|
||||||
|
|
||||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
|
|
||||||
model_parameter = ModelParameter(
|
|
||||||
base_test_env["model"],
|
|
||||||
base_test_env["tokenizer"],
|
|
||||||
base_test_env["transformer_config"]
|
|
||||||
)
|
|
||||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
|
||||||
|
|
||||||
# First training run - stop after a few steps
|
|
||||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
|
||||||
try:
|
|
||||||
# Run for a few steps then interrupt
|
|
||||||
for i, _ in enumerate(trainer.train()):
|
|
||||||
if i >= 3: # Run for 3 steps then stop
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Load checkpoint
|
|
||||||
checkpoint_path = os.path.join(test_dir, "iter_3")
|
|
||||||
checkpoint = Checkpoint().load(checkpoint_path)
|
|
||||||
|
|
||||||
# Resume training
|
|
||||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
|
||||||
resumed_checkpoint = trainer.train(checkpoint)
|
|
||||||
|
|
||||||
# Check that training resumed from correct point
|
|
||||||
assert resumed_checkpoint.sampler_state['current_iter'] > 3
|
|
||||||
|
|
||||||
def test_sampler_across_epochs(base_test_env, random_dataset):
|
|
||||||
"""Test sampler behavior across multiple epochs"""
|
"""Test sampler behavior across multiple epochs"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
n = len(dataset)
|
n = len(dataset)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue