fix: 修复参数传递问题
This commit is contained in:
parent
e1f9901384
commit
345fd2f091
|
|
@ -81,6 +81,7 @@ def only_on_rank(rank, sync=False):
|
|||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
ret_args = None
|
||||
if get_rank() == rank:
|
||||
ret_args = func(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ class Trainer:
|
|||
world_size=config.nprocs,
|
||||
master_addr=config.master_addr,
|
||||
master_port=config.master_port,
|
||||
device_type=config.device_type,
|
||||
device_ids=config.device_ids,
|
||||
checkpoint=checkpoint
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def base_test_env(request: pytest.FixtureRequest):
|
|||
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config, f)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config).to(device=device)
|
||||
tokenizer = BpeTokenizer()
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
checkpoint_interval=3,
|
||||
accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"]
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
checkpoint_interval=1,
|
||||
accumulation_steps=2,
|
||||
random_seed=np.random.randint(1e4),
|
||||
device_type=base_test_env["device"]
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ def test_different_batch_sizes(base_test_env, random_dataset):
|
|||
checkpoint_interval=5,
|
||||
accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=np.random.randint(1000)
|
||||
random_seed=np.random.randint(1000),
|
||||
device_type=base_test_env["device"]
|
||||
)
|
||||
|
||||
assert train_config.batch_size == batch_size
|
||||
|
|
@ -59,7 +60,8 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
|||
checkpoint_interval=10,
|
||||
accumulation_steps=accumulation_steps,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"]
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
|
|
@ -96,7 +98,8 @@ def test_memory_efficient_training(base_test_env, random_dataset):
|
|||
checkpoint_interval=5,
|
||||
accumulation_steps=config["accumulation_steps"],
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"]
|
||||
)
|
||||
|
||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||
Loading…
Reference in New Issue