feat(tests): 改进测试环境配置与设备管理 (improve test environment configuration and device management)

This commit is contained in:
ViperEkura 2025-10-04 12:12:42 +08:00
parent 465a1a9373
commit e7d29ca2d5
7 changed files with 13 additions and 16 deletions

View File

@@ -16,8 +16,9 @@ from khaosz.trainer.data_util import *
matplotlib.use("Agg")
@pytest.fixture
def base_test_env():
test_dir = tempfile.mkdtemp()
def base_test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32]
@@ -41,12 +42,13 @@ def base_test_env():
with open(config_path, 'w') as f:
json.dump(config, f)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer_config = TransformerConfig().load(config_path)
model = Transformer(transformer_config)
model = Transformer(transformer_config).to(device=device)
tokenizer = BpeTokenizer()
yield {
"device": device,
"test_dir": test_dir,
"config_path": config_path,
"transformer_config": transformer_config,

View File

@@ -37,7 +37,7 @@ def test_callback_integration(base_test_env, random_dataset):
def on_epoch_end(self, trainer, **kwargs):
callback_calls.append('on_epoch_end')
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],

View File

@@ -34,7 +34,6 @@ def test_dataset_loader_random_paths(base_test_env):
train_type="seq",
load_path=pkl_paths,
max_len=64,
device="cpu"
)
assert loaded_dataset is not None
assert len(loaded_dataset) > 0
@@ -62,7 +61,6 @@ def test_dpo_strategy_with_random_data(base_test_env):
train_type="dpo",
load_path=pkl_path,
max_len=64,
device="cpu"
)
assert dpo_dataset is not None

View File

@@ -41,7 +41,7 @@ def test_early_stopping_simulation(base_test_env):
random_seed=42
)
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],

View File

@@ -10,8 +10,9 @@ from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
from tokenizers import pre_tokenizers
@pytest.fixture
def test_env():
test_dir = tempfile.mkdtemp()
def test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")

View File

@@ -47,12 +47,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
warmup_steps=10,
total_steps=20
)
train_config.strategy = StrategyFactory.load(
base_test_env["model"],
"seq"
)
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],

View File

@@ -28,6 +28,7 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
train_config.strategy = StrategyFactory.load(
base_test_env["model"],
"sft",
base_test_env["device"],
bos_token_id=2,
eos_token_id=3,
user_token_id=1,