feat(tests): 改进测试环境配置与设备管理
This commit is contained in:
parent
465a1a9373
commit
e7d29ca2d5
|
|
@ -16,8 +16,9 @@ from khaosz.trainer.data_util import *
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def base_test_env():
|
def base_test_env(request: pytest.FixtureRequest):
|
||||||
test_dir = tempfile.mkdtemp()
|
func_name = request.function.__name__
|
||||||
|
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
n_dim_choices = [8, 16, 32]
|
n_dim_choices = [8, 16, 32]
|
||||||
|
|
@ -41,12 +42,13 @@ def base_test_env():
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
transformer_config = TransformerConfig().load(config_path)
|
transformer_config = TransformerConfig().load(config_path)
|
||||||
model = Transformer(transformer_config)
|
model = Transformer(transformer_config).to(device=device)
|
||||||
tokenizer = BpeTokenizer()
|
tokenizer = BpeTokenizer()
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
|
"device": device,
|
||||||
"test_dir": test_dir,
|
"test_dir": test_dir,
|
||||||
"config_path": config_path,
|
"config_path": config_path,
|
||||||
"transformer_config": transformer_config,
|
"transformer_config": transformer_config,
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
def on_epoch_end(self, trainer, **kwargs):
|
def on_epoch_end(self, trainer, **kwargs):
|
||||||
callback_calls.append('on_epoch_end')
|
callback_calls.append('on_epoch_end')
|
||||||
|
|
||||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
|
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||||
model_parameter = ModelParameter(
|
model_parameter = ModelParameter(
|
||||||
base_test_env["model"],
|
base_test_env["model"],
|
||||||
base_test_env["tokenizer"],
|
base_test_env["tokenizer"],
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ def test_dataset_loader_random_paths(base_test_env):
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=pkl_paths,
|
load_path=pkl_paths,
|
||||||
max_len=64,
|
max_len=64,
|
||||||
device="cpu"
|
|
||||||
)
|
)
|
||||||
assert loaded_dataset is not None
|
assert loaded_dataset is not None
|
||||||
assert len(loaded_dataset) > 0
|
assert len(loaded_dataset) > 0
|
||||||
|
|
@ -62,7 +61,6 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
train_type="dpo",
|
train_type="dpo",
|
||||||
load_path=pkl_path,
|
load_path=pkl_path,
|
||||||
max_len=64,
|
max_len=64,
|
||||||
device="cpu"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dpo_dataset is not None
|
assert dpo_dataset is not None
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ def test_early_stopping_simulation(base_test_env):
|
||||||
random_seed=42
|
random_seed=42
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
|
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||||
model_parameter = ModelParameter(
|
model_parameter = ModelParameter(
|
||||||
base_test_env["model"],
|
base_test_env["model"],
|
||||||
base_test_env["tokenizer"],
|
base_test_env["tokenizer"],
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,9 @@ from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
|
||||||
from tokenizers import pre_tokenizers
|
from tokenizers import pre_tokenizers
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_env():
|
def test_env(request: pytest.FixtureRequest):
|
||||||
test_dir = tempfile.mkdtemp()
|
func_name = request.function.__name__
|
||||||
|
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
||||||
model_path = os.path.join(test_dir, "model.safetensors")
|
model_path = os.path.join(test_dir, "model.safetensors")
|
||||||
|
|
|
||||||
|
|
@ -47,12 +47,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
warmup_steps=10,
|
warmup_steps=10,
|
||||||
total_steps=20
|
total_steps=20
|
||||||
)
|
)
|
||||||
|
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||||
train_config.strategy = StrategyFactory.load(
|
|
||||||
base_test_env["model"],
|
|
||||||
"seq"
|
|
||||||
)
|
|
||||||
|
|
||||||
model_parameter = ModelParameter(
|
model_parameter = ModelParameter(
|
||||||
base_test_env["model"],
|
base_test_env["model"],
|
||||||
base_test_env["tokenizer"],
|
base_test_env["tokenizer"],
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
|
||||||
train_config.strategy = StrategyFactory.load(
|
train_config.strategy = StrategyFactory.load(
|
||||||
base_test_env["model"],
|
base_test_env["model"],
|
||||||
"sft",
|
"sft",
|
||||||
|
base_test_env["device"],
|
||||||
bos_token_id=2,
|
bos_token_id=2,
|
||||||
eos_token_id=3,
|
eos_token_id=3,
|
||||||
user_token_id=1,
|
user_token_id=1,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue