test(checkpoint): 添加多进程检查点测试功能

This commit is contained in:
ViperEkura 2026-01-08 22:04:39 +08:00
parent ff5c8a71f5
commit 9dab96c31f
1 changed files with 43 additions and 3 deletions

View File

@ -1,9 +1,12 @@
from pathlib import Path import os
import tempfile
import torch import torch
import tempfile
from pathlib import Path
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.checkpoint import Checkpoint from khaosz.data.checkpoint import Checkpoint
from khaosz.parallel.setup import spawn_parallel_fn
def test_single_process(): def test_single_process():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
@ -48,5 +51,42 @@ def test_single_process():
png_files = list(Path(tmpdir).glob("*.png")) png_files = list(Path(tmpdir).glob("*.png"))
assert png_files assert png_files
def simple_training():
rank = int(os.environ.get('LOCAL_RANK', 0))
# 简单的训练逻辑
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=10)
# 训练步骤
for epoch in range(2):
for iteration in range(5):
x = torch.randn(16, 10)
y = torch.randn(16, 5)
loss = model(x).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
checkpoint = Checkpoint(
optimizer_state_dict=optimizer.state_dict(),
scheduler_state_dict=scheduler.state_dict(),
epoch=2,
iteration=10,
metrics={"loss": [0.3, 0.2, 0.1]}
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
loaded = Checkpoint.load(tmpdir)
assert loaded.epoch == 2
print(f"Rank {rank}: Checkpoint test passed")
def test_multi_process(): def test_multi_process():
pass spawn_parallel_fn(
simple_training,
world_size=2,
backend="gloo"
)