fix(trainer): 更新检查点保存和加载逻辑

This commit is contained in:
ViperEkura 2026-01-08 19:04:08 +08:00
parent 3d8047fa1b
commit d407962ffa
7 changed files with 66 additions and 24 deletions

View File

@ -1,15 +1,12 @@
import os import os
import json import json
import torch
import torch.distributed as dist
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Any from typing import Dict, Optional, Any
from khaosz.parallel.setup import get_rank
import torch.distributed as dist
from torch.distributed.checkpoint import save, load
def get_rank() -> int:
return dist.get_rank() if dist.is_initialized() else 0
class Checkpoint: class Checkpoint:
@ -53,8 +50,8 @@ class Checkpoint:
"optimizer": self.optimizer_state_dict, "optimizer": self.optimizer_state_dict,
"scheduler": self.scheduler_state_dict "scheduler": self.scheduler_state_dict
} }
with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
save(state_dict, checkpoint_id=str(save_path)) torch.save(state_dict, f)
@classmethod @classmethod
def load( def load(
@ -62,9 +59,9 @@ class Checkpoint:
save_dir: str, save_dir: str,
) -> "Checkpoint": ) -> "Checkpoint":
save_path = str(Path(save_dir))
rank = get_rank() rank = get_rank()
save_path = Path(save_dir)
meta = {} meta = {}
if rank == 0: if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f: with open(Path(save_dir) / "meta.json", "r") as f:
@ -75,11 +72,8 @@ class Checkpoint:
dist.broadcast_object_list(meta_list, src=0) dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0] meta = meta_list[0]
state_dict = { with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
"optimizer": {}, state_dict = torch.load(f)
"scheduler": {}
}
load(state_dict, checkpoint_id=save_path, no_dist=True)
return cls( return cls(
optimizer_state_dict=state_dict["optimizer"], optimizer_state_dict=state_dict["optimizer"],

View File

@ -1,5 +1,4 @@
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
from khaosz.trainer.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory from khaosz.trainer.strategy import StrategyFactory
from khaosz.trainer.schedule import SchedulerFactory from khaosz.trainer.schedule import SchedulerFactory
@ -16,9 +15,6 @@ __all__ = [
# trainer # trainer
"Trainer", "Trainer",
# checkpoint
"Checkpoint",
# factory # factory
"StrategyFactory", "StrategyFactory",
"SchedulerFactory", "SchedulerFactory",

View File

@ -17,7 +17,7 @@ from khaosz.trainer.metric_util import (
grad_std, grad_std,
grad_nan_num grad_nan_num
) )
from khaosz.trainer.checkpoint import Checkpoint from khaosz.data.checkpoint import Checkpoint
if TYPE_CHECKING: if TYPE_CHECKING:
from khaosz.trainer.train_context import TrainContext from khaosz.trainer.train_context import TrainContext

View File

@ -4,7 +4,7 @@ from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from khaosz.data import ResumableDistributedSampler from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.checkpoint import Checkpoint from khaosz.data.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
from khaosz.parallel.setup import get_current_device, get_world_size, get_rank from khaosz.parallel.setup import get_current_device, get_world_size, get_rank

View File

@ -9,7 +9,7 @@ from khaosz.trainer.train_callback import (
SchedulerCallback SchedulerCallback
) )
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.trainer.checkpoint import Checkpoint from khaosz.data.checkpoint import Checkpoint
from khaosz.parallel.setup import spawn_parallel_fn from khaosz.parallel.setup import spawn_parallel_fn
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

52
tests/test_checkpoint.py Normal file
View File

@ -0,0 +1,52 @@
from pathlib import Path
import tempfile
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.checkpoint import Checkpoint
def test_single_process():
    """Round-trip a Checkpoint through save()/load() in a temp dir.

    Runs a few real optimization steps so the optimizer and scheduler
    carry non-trivial state, saves a checkpoint (with the metric plot),
    then reloads it and verifies metadata, metrics, optimizer state
    structure, and that a .png plot file was written.
    """
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # A few epochs of training so optimizer.state is populated.
    for epoch in range(3):
        for iteration in range(10):
            x = torch.randn(32, 10)
            y = torch.randn(32, 5)
            # Real regression loss against the target. The original computed
            # model(x).mean() and left `y` entirely unused.
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Step the LR schedule once per epoch.
        scheduler.step()

    checkpoint = Checkpoint(
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        epoch=3,
        iteration=30,
        metrics={
            "loss": [0.5, 0.4, 0.3, 0.2, 0.1],
            "accuracy": [0.6, 0.7, 0.8, 0.85, 0.9]
        }
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir, save_metric_plot=True)
        loaded_checkpoint = Checkpoint.load(tmpdir)

        # Metadata and metrics survive the round trip unchanged.
        assert loaded_checkpoint.epoch == 3
        assert loaded_checkpoint.iteration == 30
        assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]

        # torch optimizer state_dict always carries these two keys.
        assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
        assert 'state' in loaded_checkpoint.optimizer_state_dict

        # save_metric_plot=True must have produced at least one plot image.
        png_files = list(Path(tmpdir).glob("*.png"))
        assert png_files
def test_multi_process():
    """Placeholder for the distributed checkpoint test.

    TODO: spawn torch.distributed worker processes and verify that the
    per-rank ``state_dict_rank_{rank}.pt`` files (see Checkpoint.save/load
    in this commit) round-trip correctly under multiple ranks.
    """
    pass

View File

@ -3,7 +3,7 @@ import torch
import numpy as np import numpy as np
from khaosz.config import * from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.checkpoint import Checkpoint
def test_early_stopping_simulation(base_test_env, early_stopping_dataset): def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior""" """Simulate early stopping behavior"""