From 9dab96c31f5a5831eab9c49477f26f3f1a1a883e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 8 Jan 2026 22:04:39 +0800 Subject: [PATCH] =?UTF-8?q?test(checkpoint):=20=E6=B7=BB=E5=8A=A0=E5=A4=9A?= =?UTF-8?q?=E8=BF=9B=E7=A8=8B=E6=A3=80=E6=9F=A5=E7=82=B9=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/data/test_checkpoint.py | 46 ++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index 79633df..ec1dc46 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -1,9 +1,12 @@ -from pathlib import Path -import tempfile +import os import torch +import tempfile + +from pathlib import Path from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from khaosz.data.checkpoint import Checkpoint +from khaosz.parallel.setup import spawn_parallel_fn def test_single_process(): model = torch.nn.Linear(10, 5) @@ -48,5 +51,42 @@ def test_single_process(): png_files = list(Path(tmpdir).glob("*.png")) assert png_files +def simple_training(): + rank = int(os.environ.get('LOCAL_RANK', 0)) + + # 简单的训练逻辑 + model = torch.nn.Linear(10, 5) + optimizer = AdamW(model.parameters(), lr=1e-3) + scheduler = CosineAnnealingLR(optimizer, T_max=10) + + # 训练步骤 + for epoch in range(2): + for iteration in range(5): + x = torch.randn(16, 10) + y = torch.randn(16, 5) + loss = model(x).mean() + loss.backward() + optimizer.step() + optimizer.zero_grad() + scheduler.step() + + checkpoint = Checkpoint( + optimizer_state_dict=optimizer.state_dict(), + scheduler_state_dict=scheduler.state_dict(), + epoch=2, + iteration=10, + metrics={"loss": [0.3, 0.2, 0.1]} + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint.save(tmpdir) + loaded = Checkpoint.load(tmpdir) + assert loaded.epoch == 2 + print(f"Rank {rank}: Checkpoint test passed") + def test_multi_process(): - pass \ No newline at end of file + spawn_parallel_fn( + simple_training, + world_size=2, + backend="gloo" + ) \ No newline at end of file