AstrAI/tests/data/test_checkpoint.py


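"""Tests for astrai.data.serialization.Checkpoint: a single-process
save/load round trip and a multi-process save-on-rank-0,
load-on-all-ranks round trip."""
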
import shutil
import tempfile

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from astrai.data.serialization import Checkpoint
from astrai.parallel.setup import get_rank, spawn_parallel_fn


def test_single_process():
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)
    for epoch in range(3):
        for iteration in range(10):
            x = torch.randn(32, 10)
            y = torch.randn(32, 5)
            # Regress random targets so backward() has a real objective.
            loss = F.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Cosine annealing is stepped once per epoch.
        scheduler.step()
    checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30)
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir)
        loaded_checkpoint = Checkpoint.load(tmpdir)
    assert loaded_checkpoint.epoch == 3
    assert loaded_checkpoint.iteration == 30
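

# A sketch of a weight round-trip check, not exercised elsewhere in this
# file: it assumes the loaded Checkpoint exposes a `state_dict` attribute
# mirroring the constructor argument.
def test_state_dict_round_trip():
    model = torch.nn.Linear(10, 5)
    checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=1, iteration=1)
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir)
        loaded = Checkpoint.load(tmpdir)
    # Restore the weights into a fresh model and compare parameter tensors.
    restored = torch.nn.Linear(10, 5)
    restored.load_state_dict(loaded.state_dict)
    for original, round_tripped in zip(model.parameters(), restored.parameters()):
        assert torch.equal(original, round_tripped)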


def simple_training():
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)
    for epoch in range(2):
        for iteration in range(5):
            x = torch.randn(16, 10)
            y = torch.randn(16, 5)
            loss = F.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()
    checkpoint = Checkpoint(
        state_dict=model.state_dict(),
        epoch=2,
        iteration=10,
    )
    # Rank 0 writes the checkpoint to a fresh directory, then broadcasts the
    # path so every rank loads from the same location.
    rank = get_rank()
    if rank == 0:
        shared_dir = tempfile.mkdtemp()
        checkpoint.save(shared_dir)
    else:
        shared_dir = None
    if dist.is_initialized():
        dir_list = [shared_dir]
        dist.broadcast_object_list(dir_list, src=0)
        shared_dir = dir_list[0]
    loaded = Checkpoint.load(shared_dir)
    assert loaded.epoch == 2
    # Wait for every rank to finish loading before rank 0 removes the
    # directory; mkdtemp does not clean up after itself.
    if dist.is_initialized():
        dist.barrier()
    if rank == 0:
        shutil.rmtree(shared_dir)


def test_multi_process():
    # gloo runs on CPU, so this test does not need GPUs.
    spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
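

# A minimal sketch for running this file directly rather than under pytest;
# it assumes spawn_parallel_fn blocks until all worker processes exit.
if __name__ == "__main__":
    test_single_process()
    test_state_dict_round_trip()
    test_multi_process()
    print("all checkpoint tests passed")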