feat(tests): 改进测试环境配置与设备管理
This commit is contained in:
parent
465a1a9373
commit
e7d29ca2d5
|
|
@ -16,8 +16,9 @@ from khaosz.trainer.data_util import *
|
|||
matplotlib.use("Agg")
|
||||
|
||||
@pytest.fixture
|
||||
def base_test_env():
|
||||
test_dir = tempfile.mkdtemp()
|
||||
def base_test_env(request: pytest.FixtureRequest):
|
||||
func_name = request.function.__name__
|
||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
||||
n_dim_choices = [8, 16, 32]
|
||||
|
|
@ -41,12 +42,13 @@ def base_test_env():
|
|||
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config, f)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
transformer_config = TransformerConfig().load(config_path)
|
||||
model = Transformer(transformer_config)
|
||||
model = Transformer(transformer_config).to(device=device)
|
||||
tokenizer = BpeTokenizer()
|
||||
|
||||
yield {
|
||||
"device": device,
|
||||
"test_dir": test_dir,
|
||||
"config_path": config_path,
|
||||
"transformer_config": transformer_config,
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
def on_epoch_end(self, trainer, **kwargs):
|
||||
callback_calls.append('on_epoch_end')
|
||||
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||
model_parameter = ModelParameter(
|
||||
base_test_env["model"],
|
||||
base_test_env["tokenizer"],
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ def test_dataset_loader_random_paths(base_test_env):
|
|||
train_type="seq",
|
||||
load_path=pkl_paths,
|
||||
max_len=64,
|
||||
device="cpu"
|
||||
)
|
||||
assert loaded_dataset is not None
|
||||
assert len(loaded_dataset) > 0
|
||||
|
|
@ -62,7 +61,6 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
train_type="dpo",
|
||||
load_path=pkl_path,
|
||||
max_len=64,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
assert dpo_dataset is not None
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def test_early_stopping_simulation(base_test_env):
|
|||
random_seed=42
|
||||
)
|
||||
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||
model_parameter = ModelParameter(
|
||||
base_test_env["model"],
|
||||
base_test_env["tokenizer"],
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@ from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
|
|||
from tokenizers import pre_tokenizers
|
||||
|
||||
@pytest.fixture
|
||||
def test_env():
|
||||
test_dir = tempfile.mkdtemp()
|
||||
def test_env(request: pytest.FixtureRequest):
|
||||
func_name = request.function.__name__
|
||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
||||
model_path = os.path.join(test_dir, "model.safetensors")
|
||||
|
|
|
|||
|
|
@ -47,12 +47,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
|||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
|
||||
train_config.strategy = StrategyFactory.load(
|
||||
base_test_env["model"],
|
||||
"seq"
|
||||
)
|
||||
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||
model_parameter = ModelParameter(
|
||||
base_test_env["model"],
|
||||
base_test_env["tokenizer"],
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
|
|||
train_config.strategy = StrategyFactory.load(
|
||||
base_test_env["model"],
|
||||
"sft",
|
||||
base_test_env["device"],
|
||||
bos_token_id=2,
|
||||
eos_token_id=3,
|
||||
user_token_id=1,
|
||||
|
|
|
|||
Loading…
Reference in New Issue