test(checkpoint): 添加多进程检查点测试功能
This commit is contained in:
parent
ff5c8a71f5
commit
9dab96c31f
|
|
@ -1,9 +1,12 @@
|
|||
from pathlib import Path
|
||||
import tempfile
|
||||
import os
|
||||
import torch
|
||||
import tempfile
|
||||
|
||||
from pathlib import Path
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
from khaosz.parallel.setup import spawn_parallel_fn
|
||||
|
||||
def test_single_process():
|
||||
model = torch.nn.Linear(10, 5)
|
||||
|
|
@ -48,5 +51,42 @@ def test_single_process():
|
|||
png_files = list(Path(tmpdir).glob("*.png"))
|
||||
assert png_files
|
||||
|
||||
def simple_training():
    """Run a tiny training loop, then save and reload a ``Checkpoint``.

    The rank is read from the ``LOCAL_RANK`` environment variable
    (defaulting to 0) and is used only in the final status message.

    Raises:
        AssertionError: If the reloaded checkpoint reports a different
            epoch than the one that was saved.
    """
    rank = int(os.environ.get('LOCAL_RANK', 0))

    num_epochs = 2
    steps_per_epoch = 5
    total_steps = num_epochs * steps_per_epoch

    # Minimal model / optimizer / scheduler setup.
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # Training steps on random data.
    for _ in range(num_epochs):
        for _ in range(steps_per_epoch):
            x = torch.randn(16, 10)
            y = torch.randn(16, 5)
            # Regress against the target so `y` actually takes part in the loss.
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # NOTE(review): the scheduler is stepped once per optimizer step,
            # which matches T_max=10 total steps -- confirm this is the
            # intended schedule (vs. once per epoch).
            scheduler.step()

    checkpoint = Checkpoint(
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        epoch=num_epochs,
        iteration=total_steps,
        metrics={"loss": [0.3, 0.2, 0.1]},
    )

    # Round-trip through a throwaway directory that is removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir)
        loaded = Checkpoint.load(tmpdir)
        assert loaded.epoch == num_epochs

    print(f"Rank {rank}: Checkpoint test passed")
|
||||
|
||||
def test_multi_process():
    """Launch ``simple_training`` via ``spawn_parallel_fn`` with world_size=2 and the gloo backend."""
    spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
|
||||
Loading…
Reference in New Issue