test(checkpoint): 添加多进程检查点测试功能
This commit is contained in:
parent
ff5c8a71f5
commit
9dab96c31f
|
|
@ -1,9 +1,12 @@
|
||||||
from pathlib import Path
|
import os
|
||||||
import tempfile
|
|
||||||
import torch
|
import torch
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from khaosz.data.checkpoint import Checkpoint
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
|
from khaosz.parallel.setup import spawn_parallel_fn
|
||||||
|
|
||||||
def test_single_process():
|
def test_single_process():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
|
|
@ -48,5 +51,42 @@ def test_single_process():
|
||||||
png_files = list(Path(tmpdir).glob("*.png"))
|
png_files = list(Path(tmpdir).glob("*.png"))
|
||||||
assert png_files
|
assert png_files
|
||||||
|
|
||||||
|
def simple_training():
    """Run a tiny training loop, then save and reload a ``Checkpoint``.

    The function reads ``LOCAL_RANK`` from the environment (defaulting to 0),
    trains a ``Linear(10, 5)`` model on random input for a few steps, stores
    the optimizer and scheduler state in a ``Checkpoint``, writes it to a
    temporary directory, loads it back and checks the restored epoch.

    It takes no arguments and returns nothing; ``test_multi_process`` passes
    it to ``spawn_parallel_fn`` as the function to run.

    Raises:
        AssertionError: if the reloaded checkpoint's ``epoch`` differs from
            the value that was saved.
    """
    rank = int(os.environ.get('LOCAL_RANK', 0))

    # Loop bounds are named so the checkpoint metadata below cannot drift
    # away from what the loop actually ran.
    num_epochs = 2
    iters_per_epoch = 5

    # Simple training logic
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # Training steps
    for _ in range(num_epochs):
        for _ in range(iters_per_epoch):
            x = torch.randn(16, 10)
            loss = model(x).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # NOTE(review): the original indentation was lost, so it is
            # unclear whether the scheduler steps per iteration or per
            # epoch -- confirm against the real file.
            scheduler.step()

    checkpoint = Checkpoint(
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        epoch=num_epochs,
        iteration=num_epochs * iters_per_epoch,
        metrics={"loss": [0.3, 0.2, 0.1]},
    )

    # The temporary directory is created inside this call and removed when
    # the ``with`` block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir)
        loaded = Checkpoint.load(tmpdir)
        assert loaded.epoch == num_epochs
        print(f"Rank {rank}: Checkpoint test passed")
|
||||||
|
|
||||||
def test_multi_process():
    """Run ``simple_training`` through ``spawn_parallel_fn``.

    The call requests ``world_size=2`` and the ``"gloo"`` backend.
    """
    spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
|
||||||
Loading…
Reference in New Issue