52 lines
1.5 KiB
Python
52 lines
1.5 KiB
Python
from pathlib import Path
|
|
import tempfile
|
|
import torch
|
|
from torch.optim import AdamW
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
from khaosz.data.checkpoint import Checkpoint
|
|
|
|
def test_single_process():
    """Round-trip a Checkpoint through save/load within a single process.

    Trains a small linear model for a few epochs so the optimizer and
    scheduler hold non-trivial state. Builds a ``Checkpoint`` from that state
    plus fixed metric values, saves it (with a metric plot) into a temporary
    directory, and loads it back. Then checks that the counters, the metrics,
    the optimizer state-dict structure and the ``.png`` plot file are all
    present after the round trip.
    """
    # Fixed seed so the random batches, and therefore the saved state, are reproducible.
    torch.manual_seed(0)

    num_epochs = 3
    iters_per_epoch = 10

    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    for _epoch in range(num_epochs):
        for _iteration in range(iters_per_epoch):
            x = torch.randn(32, 10)
            y = torch.randn(32, 5)
            # Regress the model output onto the random target `y`.
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # The scheduler is stepped once per epoch, not once per iteration.
        scheduler.step()

    # Fixed placeholder values: they are not derived from the training loop above.
    metrics = {
        "loss": [0.5, 0.4, 0.3, 0.2, 0.1],
        "accuracy": [0.6, 0.7, 0.8, 0.85, 0.9],
    }
    expected_iteration = num_epochs * iters_per_epoch

    checkpoint = Checkpoint(
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        epoch=num_epochs,
        iteration=expected_iteration,
        metrics=metrics,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir, save_metric_plot=True)

        loaded_checkpoint = Checkpoint.load(tmpdir)

        assert loaded_checkpoint.epoch == num_epochs
        assert loaded_checkpoint.iteration == expected_iteration
        assert loaded_checkpoint.metrics["loss"] == metrics["loss"]
        assert loaded_checkpoint.metrics["accuracy"] == metrics["accuracy"]

        assert "param_groups" in loaded_checkpoint.optimizer_state_dict
        assert "state" in loaded_checkpoint.optimizer_state_dict
        # NOTE(review): scheduler_state_dict is saved but its round trip is not
        # asserted; add a check once Checkpoint.load's handling of it is confirmed.

        # The plot must be checked while tmpdir still exists (inside the `with`).
        png_files = list(Path(tmpdir).glob("*.png"))
        assert png_files, "expected save_metric_plot=True to write a .png file"
|
|
|
|
def test_multi_process():
    """Placeholder for a multi-process Checkpoint save/load test.

    Not implemented yet. It raises ``unittest.SkipTest``, which pytest reports
    as *skipped*, so the missing coverage is visible instead of showing up as
    a vacuous pass.
    """
    import unittest

    raise unittest.SkipTest("multi-process checkpoint test not implemented yet")