fix(trainer): 更新检查点保存和加载逻辑
This commit is contained in:
parent
3d8047fa1b
commit
d407962ffa
|
|
@ -1,15 +1,12 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.checkpoint import save, load
|
||||
|
||||
|
||||
def get_rank() -> int:
    """Return the rank of the current process in the default process group.

    Falls back to ``0`` when ``torch.distributed`` is not available in this
    PyTorch build or when no process group has been initialized, so a plain
    single-process run is treated as rank 0.

    Returns:
        ``dist.get_rank()`` if a process group is initialized, else ``0``.
    """
    # is_available() must be checked first: on builds without distributed
    # support the rest of the torch.distributed API is not exposed.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
||||
from khaosz.parallel.setup import get_rank
|
||||
|
||||
|
||||
class Checkpoint:
|
||||
|
|
@ -53,8 +50,8 @@ class Checkpoint:
|
|||
"optimizer": self.optimizer_state_dict,
|
||||
"scheduler": self.scheduler_state_dict
|
||||
}
|
||||
|
||||
save(state_dict, checkpoint_id=str(save_path))
|
||||
with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
|
||||
torch.save(state_dict, f)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
|
|
@ -62,9 +59,9 @@ class Checkpoint:
|
|||
save_dir: str,
|
||||
) -> "Checkpoint":
|
||||
|
||||
save_path = str(Path(save_dir))
|
||||
rank = get_rank()
|
||||
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = {}
|
||||
if rank == 0:
|
||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||
|
|
@ -75,11 +72,8 @@ class Checkpoint:
|
|||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
state_dict = {
|
||||
"optimizer": {},
|
||||
"scheduler": {}
|
||||
}
|
||||
load(state_dict, checkpoint_id=save_path, no_dist=True)
|
||||
with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
|
||||
return cls(
|
||||
optimizer_state_dict=state_dict["optimizer"],
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.trainer.strategy import StrategyFactory
|
||||
from khaosz.trainer.schedule import SchedulerFactory
|
||||
|
||||
|
|
@ -16,9 +15,6 @@ __all__ = [
|
|||
# trainer
|
||||
"Trainer",
|
||||
|
||||
# checkpoint
|
||||
"Checkpoint",
|
||||
|
||||
# factory
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory",
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from khaosz.trainer.metric_util import (
|
|||
grad_std,
|
||||
grad_nan_num
|
||||
)
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from khaosz.trainer.train_context import TrainContext
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from torch.optim.lr_scheduler import LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from khaosz.data import ResumableDistributedSampler
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
||||
from khaosz.config.train_config import TrainConfig
|
||||
from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from khaosz.trainer.train_callback import (
|
|||
SchedulerCallback
|
||||
)
|
||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
from khaosz.parallel.setup import spawn_parallel_fn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
from pathlib import Path
|
||||
import tempfile
|
||||
import torch
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
|
||||
def test_single_process():
    """Round-trip a ``Checkpoint`` through ``save``/``load`` in one process.

    Trains a tiny linear model for a few steps, stores the optimizer and
    scheduler state dicts together with epoch/iteration counters and metric
    histories, then checks that the reloaded checkpoint exposes the same
    counters, the same loss history, an optimizer state dict with the usual
    top-level keys, and that at least one PNG was written to the directory.
    """
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # Take real optimization steps so the optimizer state dict being saved is
    # the one produced by actual training updates.
    for _ in range(3):
        for _ in range(10):
            x = torch.randn(32, 10)
            y = torch.randn(32, 5)
            # Regress against the target; previously ``y`` was generated but
            # never used and the loss was just the mean of the output.
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    checkpoint = Checkpoint(
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        epoch=3,
        iteration=30,
        metrics={
            "loss": [0.5, 0.4, 0.3, 0.2, 0.1],
            "accuracy": [0.6, 0.7, 0.8, 0.85, 0.9]
        }
    )

    # Everything that touches the directory stays inside the ``with`` block,
    # because the temporary directory is removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir, save_metric_plot=True)

        loaded_checkpoint = Checkpoint.load(tmpdir)

        assert loaded_checkpoint.epoch == 3
        assert loaded_checkpoint.iteration == 30
        assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]

        assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
        assert 'state' in loaded_checkpoint.optimizer_state_dict

        png_files = list(Path(tmpdir).glob("*.png"))
        assert png_files
|
||||
|
||||
def test_multi_process():
    """Placeholder for the multi-process checkpoint test; currently a no-op."""
|
||||
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import numpy as np
|
||||
from khaosz.config import *
|
||||
from khaosz.trainer import *
|
||||
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
|
||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||
"""Simulate early stopping behavior"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue