fix(trainer): update checkpoint save and load logic

Author: ViperEkura
Date:   2026-01-08 19:04:08 +08:00
parent 3d8047fa1b
commit d407962ffa
7 changed files with 66 additions and 24 deletions

View File

@@ -1,15 +1,12 @@
 import os
 import json
 import torch
+import torch.distributed as dist
 import matplotlib.pyplot as plt
 from pathlib import Path
 from typing import Dict, Optional, Any
-import torch.distributed as dist
-from torch.distributed.checkpoint import save, load
-
-def get_rank() -> int:
-    return dist.get_rank() if dist.is_initialized() else 0
+from khaosz.parallel.setup import get_rank
 
 
 class Checkpoint:
class Checkpoint:
@@ -53,8 +50,8 @@ class Checkpoint:
             "optimizer": self.optimizer_state_dict,
             "scheduler": self.scheduler_state_dict
         }
-        save(state_dict, checkpoint_id=str(save_path))
+        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
+            torch.save(state_dict, f)
 
     @classmethod
     def load(
@@ -62,9 +59,9 @@ class Checkpoint:
         save_dir: str,
     ) -> "Checkpoint":
-        save_path = str(Path(save_dir))
         rank = get_rank()
+        save_path = Path(save_dir)
         meta = {}
         if rank == 0:
             with open(Path(save_dir) / "meta.json", "r") as f:
@@ -75,11 +72,8 @@ class Checkpoint:
             dist.broadcast_object_list(meta_list, src=0)
             meta = meta_list[0]
 
-        state_dict = {
-            "optimizer": {},
-            "scheduler": {}
-        }
-        load(state_dict, checkpoint_id=save_path, no_dist=True)
+        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
+            state_dict = torch.load(f)
 
         return cls(
             optimizer_state_dict=state_dict["optimizer"],
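Taken together, the change replaces torch.distributed.checkpoint's save/load with one plain torch.save/torch.load file per rank, while rank 0 still reads meta.json and shares it through dist.broadcast_object_list. A minimal single-process sketch of the new round trip (the directory name and dict contents are illustrative; get_rank is inlined here with the same fallback the helper this commit moved to khaosz.parallel.setup uses):

    import torch
    import torch.distributed as dist
    from pathlib import Path

    def get_rank() -> int:
        # same fallback behavior as the removed in-module helper
        return dist.get_rank() if dist.is_initialized() else 0

    save_path = Path("checkpoint_dir")  # illustrative location
    save_path.mkdir(exist_ok=True)

    state_dict = {"optimizer": {"state": {}, "param_groups": []}, "scheduler": {}}

    # save: each rank writes its own file, no distributed coordination needed
    with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
        torch.save(state_dict, f)

    # load: each rank reads back exactly the file it wrote
    with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
        restored = torch.load(f)

    assert restored == state_dict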

View File

@@ -1,5 +1,4 @@
 from khaosz.trainer.trainer import Trainer
-from khaosz.trainer.checkpoint import Checkpoint
 from khaosz.trainer.strategy import StrategyFactory
 from khaosz.trainer.schedule import SchedulerFactory
@@ -16,9 +15,6 @@ __all__ = [
     # trainer
     "Trainer",
-
-    # checkpoint
-    "Checkpoint",
 
     # factory
     "StrategyFactory",
     "SchedulerFactory",

View File

@@ -17,7 +17,7 @@ from khaosz.trainer.metric_util import (
     grad_std,
     grad_nan_num
 )
-from khaosz.trainer.checkpoint import Checkpoint
+from khaosz.data.checkpoint import Checkpoint
 
 if TYPE_CHECKING:
     from khaosz.trainer.train_context import TrainContext

View File

@@ -4,7 +4,7 @@ from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
 from khaosz.data import ResumableDistributedSampler
-from khaosz.trainer.checkpoint import Checkpoint
+from khaosz.data.checkpoint import Checkpoint
 from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
 from khaosz.config.train_config import TrainConfig
 from khaosz.parallel.setup import get_current_device, get_world_size, get_rank

View File

@@ -9,7 +9,7 @@ from khaosz.trainer.train_callback import (
     SchedulerCallback
 )
 from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
-from khaosz.trainer.checkpoint import Checkpoint
+from khaosz.data.checkpoint import Checkpoint
 from khaosz.parallel.setup import spawn_parallel_fn
 
 logger = logging.getLogger(__name__)

tests/test_checkpoint.py (new file, 52 additions)
View File

@@ -0,0 +1,52 @@
+from pathlib import Path
+import tempfile
+import torch
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from khaosz.data.checkpoint import Checkpoint
+
+
+def test_single_process():
+    model = torch.nn.Linear(10, 5)
+    optimizer = AdamW(model.parameters(), lr=1e-3)
+    scheduler = CosineAnnealingLR(optimizer, T_max=10)
+
+    for epoch in range(3):
+        for iteration in range(10):
+            x = torch.randn(32, 10)
+            y = torch.randn(32, 5)
+            loss = model(x).mean()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        scheduler.step()
+
+    checkpoint = Checkpoint(
+        optimizer_state_dict=optimizer.state_dict(),
+        scheduler_state_dict=scheduler.state_dict(),
+        epoch=3,
+        iteration=30,
+        metrics={
+            "loss": [0.5, 0.4, 0.3, 0.2, 0.1],
+            "accuracy": [0.6, 0.7, 0.8, 0.85, 0.9]
+        }
+    )
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        checkpoint.save(tmpdir, save_metric_plot=True)
+        loaded_checkpoint = Checkpoint.load(tmpdir)
+
+        assert loaded_checkpoint.epoch == 3
+        assert loaded_checkpoint.iteration == 30
+        assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
+        assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
+        assert 'state' in loaded_checkpoint.optimizer_state_dict
+
+        png_files = list(Path(tmpdir).glob("*.png"))
+        assert png_files
+
+
+def test_multi_process():
+    pass
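test_multi_process is left as an empty stub in this commit. For reference, a hypothetical sketch of how it might eventually exercise the per-rank files, spawning two gloo ranks on CPU (none of this is in the commit; the helper name and port are made up):

    import os
    import tempfile
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def _rank_round_trip(rank: int, world_size: int, tmpdir: str):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29501"  # arbitrary free port
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        state = {"optimizer": {"rank": rank}, "scheduler": {}}
        path = os.path.join(tmpdir, f"state_dict_rank_{dist.get_rank()}.pt")
        torch.save(state, path)
        dist.barrier()  # ensure every rank has written before reading
        restored = torch.load(path)
        assert restored["optimizer"]["rank"] == rank  # each rank sees its own file
        dist.destroy_process_group()

    def test_multi_process_sketch():
        with tempfile.TemporaryDirectory() as tmpdir:
            # mp.spawn passes the rank as the first argument to the worker
            mp.spawn(_rank_round_trip, args=(2, tmpdir), nprocs=2)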

View File

@@ -3,7 +3,7 @@ import torch
 import numpy as np
 from khaosz.config import *
 from khaosz.trainer import *
+from khaosz.data.checkpoint import Checkpoint
 
 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""