fix: 修复参数传递问题

This commit is contained in:
ViperEkura 2026-03-30 22:22:36 +08:00
parent e1f9901384
commit 345fd2f091
6 changed files with 13 additions and 5 deletions

View File

@@ -81,6 +81,7 @@ def only_on_rank(rank, sync=False):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
ret_args = None
if get_rank() == rank:
ret_args = func(*args, **kwargs)

View File

@@ -57,6 +57,8 @@ class Trainer:
world_size=config.nprocs,
master_addr=config.master_addr,
master_port=config.master_port,
device_type=config.device_type,
device_ids=config.device_ids,
checkpoint=checkpoint
)

View File

@@ -96,7 +96,7 @@ def base_test_env(request: pytest.FixtureRequest):
with open(config_path, 'w') as f:
json.dump(config, f)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config).to(device=device)
tokenizer = BpeTokenizer()

View File

@@ -25,7 +25,8 @@ def test_callback_integration(base_test_env, random_dataset):
checkpoint_interval=3,
accumulation_steps=1,
max_grad_norm=1.0,
random_seed=42
random_seed=42,
device_type=base_test_env["device"]
)

View File

@@ -25,6 +25,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
checkpoint_interval=1,
accumulation_steps=2,
random_seed=np.random.randint(1e4),
device_type=base_test_env["device"]
)
trainer = Trainer(train_config)

View File

@@ -30,7 +30,8 @@ def test_different_batch_sizes(base_test_env, random_dataset):
checkpoint_interval=5,
accumulation_steps=1,
max_grad_norm=1.0,
random_seed=np.random.randint(1000)
random_seed=np.random.randint(1000),
device_type=base_test_env["device"]
)
assert train_config.batch_size == batch_size
@@ -59,7 +60,8 @@ def test_gradient_accumulation(base_test_env, random_dataset):
checkpoint_interval=10,
accumulation_steps=accumulation_steps,
max_grad_norm=1.0,
random_seed=42
random_seed=42,
device_type=base_test_env["device"]
)
trainer = Trainer(train_config)
@@ -96,7 +98,8 @@ def test_memory_efficient_training(base_test_env, random_dataset):
checkpoint_interval=5,
accumulation_steps=config["accumulation_steps"],
max_grad_norm=1.0,
random_seed=42
random_seed=42,
device_type=base_test_env["device"]
)
assert train_config.accumulation_steps == config["accumulation_steps"]