Patch for tests/data/test_checkpoint.py.

Summary of what the hunks below do:
  * Import block: adds "import os", moves "import tempfile" and
    "from pathlib import Path" below "import torch", and imports
    spawn_parallel_fn from khaosz.parallel.setup.
  * Adds simple_training(): reads LOCAL_RANK from the environment
    (default 0), runs 2 x 5 optimizer steps on a torch.nn.Linear(10, 5)
    model, builds a Checkpoint from the optimizer/scheduler state, saves it
    into a temporary directory, reloads it and asserts the loaded epoch is 2.
  * test_multi_process(): replaces the "pass" placeholder with a call to
    spawn_parallel_fn(simple_training, world_size=2, backend="gloo").

NOTE(review): this patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed. Indentation below is reconstructed
from Python conventions; in particular the nesting level of
scheduler.step() and the indentation of the context lines are assumptions.
Confirm against the target file before applying.
NOTE(review): in simple_training() the tensor y is created but the loss is
model(x).mean(), so y is never used. Confirm whether that is intended.

diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py
index 79633df..ec1dc46 100644
--- a/tests/data/test_checkpoint.py
+++ b/tests/data/test_checkpoint.py
@@ -1,9 +1,12 @@
-from pathlib import Path
-import tempfile
+import os
 import torch
+import tempfile
+
+from pathlib import Path
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from khaosz.data.checkpoint import Checkpoint
+from khaosz.parallel.setup import spawn_parallel_fn
 
 def test_single_process():
     model = torch.nn.Linear(10, 5)
@@ -48,5 +51,42 @@ def test_single_process():
         png_files = list(Path(tmpdir).glob("*.png"))
         assert png_files
 
+def simple_training():
+    rank = int(os.environ.get('LOCAL_RANK', 0))
+
+    # Simple training logic
+    model = torch.nn.Linear(10, 5)
+    optimizer = AdamW(model.parameters(), lr=1e-3)
+    scheduler = CosineAnnealingLR(optimizer, T_max=10)
+
+    # Training steps
+    for epoch in range(2):
+        for iteration in range(5):
+            x = torch.randn(16, 10)
+            y = torch.randn(16, 5)
+            loss = model(x).mean()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        scheduler.step()
+
+    checkpoint = Checkpoint(
+        optimizer_state_dict=optimizer.state_dict(),
+        scheduler_state_dict=scheduler.state_dict(),
+        epoch=2,
+        iteration=10,
+        metrics={"loss": [0.3, 0.2, 0.1]}
+    )
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        checkpoint.save(tmpdir)
+        loaded = Checkpoint.load(tmpdir)
+        assert loaded.epoch == 2
+        print(f"Rank {rank}: Checkpoint test passed")
+
 def test_multi_process():
-    pass
\ No newline at end of file
+    spawn_parallel_fn(
+        simple_training,
+        world_size=2,
+        backend="gloo"
+    )
\ No newline at end of file