From 345fd2f0914f31622177e3c6b76dc02c927e65a5 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 30 Mar 2026 22:22:36 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E4=BC=A0=E9=80=92=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/parallel/setup.py | 1 + khaosz/trainer/trainer.py | 2 ++ tests/conftest.py | 2 +- tests/trainer/test_callbacks.py | 3 ++- tests/trainer/test_early_stopping.py | 1 + tests/trainer/test_trainer.py | 9 ++++++--- 6 files changed, 13 insertions(+), 5 deletions(-) diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py index 452e30f..4c81b19 100644 --- a/khaosz/parallel/setup.py +++ b/khaosz/parallel/setup.py @@ -81,6 +81,7 @@ def only_on_rank(rank, sync=False): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + ret_args = None if get_rank() == rank: ret_args = func(*args, **kwargs) diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index a75ef36..6d2415f 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -57,6 +57,8 @@ class Trainer: world_size=config.nprocs, master_addr=config.master_addr, master_port=config.master_port, + device_type=config.device_type, + device_ids=config.device_ids, checkpoint=checkpoint ) diff --git a/tests/conftest.py b/tests/conftest.py index 032d3f7..2dacbbb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,7 +96,7 @@ def base_test_env(request: pytest.FixtureRequest): with open(config_path, 'w') as f: json.dump(config, f) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = "cuda" if torch.cuda.is_available() else "cpu" transformer_config = ModelConfig().load(config_path) model = Transformer(transformer_config).to(device=device) tokenizer = BpeTokenizer() diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 7855d7d..0ae63d8 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -25,7 +25,8 @@ def test_callback_integration(base_test_env, random_dataset): checkpoint_interval=3, accumulation_steps=1, max_grad_norm=1.0, - random_seed=42 + random_seed=42, + device_type=base_test_env["device"] ) diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index b8893f9..dc794f4 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -25,6 +25,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): checkpoint_interval=1, accumulation_steps=2, random_seed=np.random.randint(1e4), + device_type=base_test_env["device"] ) trainer = Trainer(train_config) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 27261ce..ce93bec 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -30,7 +30,8 @@ def test_different_batch_sizes(base_test_env, random_dataset): checkpoint_interval=5, accumulation_steps=1, max_grad_norm=1.0, - random_seed=np.random.randint(1000) + random_seed=np.random.randint(1000), + device_type=base_test_env["device"] ) assert train_config.batch_size == batch_size @@ -59,7 +60,8 @@ def test_gradient_accumulation(base_test_env, random_dataset): checkpoint_interval=10, accumulation_steps=accumulation_steps, max_grad_norm=1.0, - random_seed=42 + random_seed=42, + device_type=base_test_env["device"] ) trainer = Trainer(train_config) @@ -96,7 +98,8 @@ def test_memory_efficient_training(base_test_env, random_dataset): checkpoint_interval=5, accumulation_steps=config["accumulation_steps"], max_grad_norm=1.0, - random_seed=42 + random_seed=42, + device_type=base_test_env["device"] ) assert train_config.accumulation_steps == config["accumulation_steps"] \ No newline at end of file