AstrAI/khaosz/data/checkpoint.py

91 lines
2.5 KiB
Python

import os
import json
import torch
import torch.distributed as dist
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, Optional, Any
from khaosz.parallel.setup import get_rank
class Checkpoint:
def __init__(
self,
state_dict: Dict[str, Any],
epoch: int = 0,
iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.metrics = metrics or {}
def save(
self,
save_dir: str,
save_metric_plot: bool = True,
) -> None:
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
rank = get_rank()
if rank == 0:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"metrics": self.metrics,
}
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
if save_metric_plot and self.metrics:
self._plot_metrics(str(save_path))
with open(save_path / f"state_dict.pt", "wb") as f:
torch.save(self.state_dict, f)
@classmethod
def load(
cls,
save_dir: str,
) -> "Checkpoint":
rank = get_rank()
save_path = Path(save_dir)
meta = {}
if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f:
meta = json.load(f)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
with open(save_path / f"state_dict.pt", "rb") as f:
state_dict = torch.load(f)
return cls(
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
metrics=meta.get("metrics", {}),
)
def _plot_metrics(self, save_dir: str):
for name, values in self.metrics.items():
if not values:
continue
plt.figure(figsize=(10, 6))
plt.plot(values, label=name)
plt.xlabel("Step")
plt.ylabel("Value")
plt.title(f"Training Metric: {name}")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=150, bbox_inches="tight")
plt.close()