feat(tests): 改进测试环境配置与设备管理 (improve test environment configuration and device management)

This commit is contained in:
ViperEkura 2025-10-04 12:12:42 +08:00
parent 465a1a9373
commit e7d29ca2d5
7 changed files with 13 additions and 16 deletions

View File

@@ -16,8 +16,9 @@ from khaosz.trainer.data_util import *
 matplotlib.use("Agg")
 @pytest.fixture
-def base_test_env():
-    test_dir = tempfile.mkdtemp()
+def base_test_env(request: pytest.FixtureRequest):
+    func_name = request.function.__name__
+    test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
     config_path = os.path.join(test_dir, "config.json")
     n_dim_choices = [8, 16, 32]
@@ -41,12 +42,13 @@ def base_test_env():
     with open(config_path, 'w') as f:
         json.dump(config, f)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     transformer_config = TransformerConfig().load(config_path)
-    model = Transformer(transformer_config)
+    model = Transformer(transformer_config).to(device=device)
     tokenizer = BpeTokenizer()
     yield {
+        "device": device,
         "test_dir": test_dir,
         "config_path": config_path,
         "transformer_config": transformer_config,

View File

@@ -37,7 +37,7 @@ def test_callback_integration(base_test_env, random_dataset):
         def on_epoch_end(self, trainer, **kwargs):
             callback_calls.append('on_epoch_end')
-    train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
+    train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
     model_parameter = ModelParameter(
         base_test_env["model"],
         base_test_env["tokenizer"],

View File

@@ -34,7 +34,6 @@ def test_dataset_loader_random_paths(base_test_env):
         train_type="seq",
         load_path=pkl_paths,
         max_len=64,
-        device="cpu"
     )
     assert loaded_dataset is not None
     assert len(loaded_dataset) > 0
@@ -62,7 +61,6 @@ def test_dpo_strategy_with_random_data(base_test_env):
         train_type="dpo",
         load_path=pkl_path,
         max_len=64,
-        device="cpu"
     )
     assert dpo_dataset is not None

View File

@@ -41,7 +41,7 @@ def test_early_stopping_simulation(base_test_env):
         random_seed=42
     )
-    train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
+    train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
     model_parameter = ModelParameter(
         base_test_env["model"],
         base_test_env["tokenizer"],

View File

@@ -10,8 +10,9 @@ from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
 from tokenizers import pre_tokenizers
 @pytest.fixture
-def test_env():
-    test_dir = tempfile.mkdtemp()
+def test_env(request: pytest.FixtureRequest):
+    func_name = request.function.__name__
+    test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
     config_path = os.path.join(test_dir, "config.json")
     tokenizer_path = os.path.join(test_dir, "tokenizer.json")
     model_path = os.path.join(test_dir, "model.safetensors")

View File

@@ -47,12 +47,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
         warmup_steps=10,
         total_steps=20
     )
-    train_config.strategy = StrategyFactory.load(
-        base_test_env["model"],
-        "seq"
-    )
+    train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
     model_parameter = ModelParameter(
         base_test_env["model"],
         base_test_env["tokenizer"],

View File

@@ -28,6 +28,7 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
     train_config.strategy = StrategyFactory.load(
         base_test_env["model"],
         "sft",
+        base_test_env["device"],
         bos_token_id=2,
         eos_token_id=3,
         user_token_id=1,