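fix: 修复参数传递问题 (fix parameter passing: forward device_type/device_ids from the config in Trainer, pass device_type in the tests, and initialize ret_args to None in the only_on_rank wrapper)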
This commit is contained in:
parent
e1f9901384
commit
345fd2f091
|
|
@ -81,6 +81,7 @@ def only_on_rank(rank, sync=False):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
ret_args = None
|
||||||
if get_rank() == rank:
|
if get_rank() == rank:
|
||||||
ret_args = func(*args, **kwargs)
|
ret_args = func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,8 @@ class Trainer:
|
||||||
world_size=config.nprocs,
|
world_size=config.nprocs,
|
||||||
master_addr=config.master_addr,
|
master_addr=config.master_addr,
|
||||||
master_port=config.master_port,
|
master_port=config.master_port,
|
||||||
|
device_type=config.device_type,
|
||||||
|
device_ids=config.device_ids,
|
||||||
checkpoint=checkpoint
|
checkpoint=checkpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,7 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
transformer_config = ModelConfig().load(config_path)
|
transformer_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(transformer_config).to(device=device)
|
model = Transformer(transformer_config).to(device=device)
|
||||||
tokenizer = BpeTokenizer()
|
tokenizer = BpeTokenizer()
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,8 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
checkpoint_interval=3,
|
checkpoint_interval=3,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42
|
random_seed=42,
|
||||||
|
device_type=base_test_env["device"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
checkpoint_interval=1,
|
checkpoint_interval=1,
|
||||||
accumulation_steps=2,
|
accumulation_steps=2,
|
||||||
random_seed=np.random.randint(1e4),
|
random_seed=np.random.randint(1e4),
|
||||||
|
device_type=base_test_env["device"]
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,8 @@ def test_different_batch_sizes(base_test_env, random_dataset):
|
||||||
checkpoint_interval=5,
|
checkpoint_interval=5,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=np.random.randint(1000)
|
random_seed=np.random.randint(1000),
|
||||||
|
device_type=base_test_env["device"]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.batch_size == batch_size
|
assert train_config.batch_size == batch_size
|
||||||
|
|
@ -59,7 +60,8 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
checkpoint_interval=10,
|
checkpoint_interval=10,
|
||||||
accumulation_steps=accumulation_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42
|
random_seed=42,
|
||||||
|
device_type=base_test_env["device"]
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
@ -96,7 +98,8 @@ def test_memory_efficient_training(base_test_env, random_dataset):
|
||||||
checkpoint_interval=5,
|
checkpoint_interval=5,
|
||||||
accumulation_steps=config["accumulation_steps"],
|
accumulation_steps=config["accumulation_steps"],
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42
|
random_seed=42,
|
||||||
|
device_type=base_test_env["device"]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||||
Loading…
Reference in New Issue