diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py index 452e30f..4c81b19 100644 --- a/khaosz/parallel/setup.py +++ b/khaosz/parallel/setup.py @@ -81,6 +81,7 @@ def only_on_rank(rank, sync=False): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + ret_args = None if get_rank() == rank: ret_args = func(*args, **kwargs) diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index a75ef36..6d2415f 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -57,6 +57,8 @@ class Trainer: world_size=config.nprocs, master_addr=config.master_addr, master_port=config.master_port, + device_type=config.device_type, + device_ids=config.device_ids, checkpoint=checkpoint ) diff --git a/tests/conftest.py b/tests/conftest.py index 032d3f7..2dacbbb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,7 +96,7 @@ def base_test_env(request: pytest.FixtureRequest): with open(config_path, 'w') as f: json.dump(config, f) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = "cuda" if torch.cuda.is_available() else "cpu" transformer_config = ModelConfig().load(config_path) model = Transformer(transformer_config).to(device=device) tokenizer = BpeTokenizer() diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 7855d7d..0ae63d8 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -25,7 +25,8 @@ def test_callback_integration(base_test_env, random_dataset): checkpoint_interval=3, accumulation_steps=1, max_grad_norm=1.0, - random_seed=42 + random_seed=42, + device_type=base_test_env["device"] ) diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index b8893f9..dc794f4 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -25,6 +25,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): checkpoint_interval=1, accumulation_steps=2, random_seed=np.random.randint(1e4), + device_type=base_test_env["device"] ) trainer = Trainer(train_config) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 27261ce..ce93bec 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -30,7 +30,8 @@ def test_different_batch_sizes(base_test_env, random_dataset): checkpoint_interval=5, accumulation_steps=1, max_grad_norm=1.0, - random_seed=np.random.randint(1000) + random_seed=np.random.randint(1000), + device_type=base_test_env["device"] ) assert train_config.batch_size == batch_size @@ -59,7 +60,8 @@ def test_gradient_accumulation(base_test_env, random_dataset): checkpoint_interval=10, accumulation_steps=accumulation_steps, max_grad_norm=1.0, - random_seed=42 + random_seed=42, + device_type=base_test_env["device"] ) trainer = Trainer(train_config) @@ -96,7 +98,8 @@ def test_memory_efficient_training(base_test_env, random_dataset): checkpoint_interval=5, accumulation_steps=config["accumulation_steps"], max_grad_norm=1.0, - random_seed=42 + random_seed=42, + device_type=base_test_env["device"] ) assert train_config.accumulation_steps == config["accumulation_steps"] \ No newline at end of file