fix(trainer): 更新检查点保存和加载逻辑
This commit is contained in:
parent
3d8047fa1b
commit
d407962ffa
|
|
@ -1,15 +1,12 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any
|
||||||
|
from khaosz.parallel.setup import get_rank
|
||||||
import torch.distributed as dist
|
|
||||||
from torch.distributed.checkpoint import save, load
|
|
||||||
|
|
||||||
|
|
||||||
def get_rank() -> int:
    """Return the distributed rank of the current process.

    Returns:
        The value of ``dist.get_rank()`` when a process group has been
        initialized, otherwise ``0``.
    """
    # Guard with is_available() first: builds without distributed support
    # do not necessarily expose is_initialized().
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
||||||
|
|
||||||
|
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
|
|
@ -53,8 +50,8 @@ class Checkpoint:
|
||||||
"optimizer": self.optimizer_state_dict,
|
"optimizer": self.optimizer_state_dict,
|
||||||
"scheduler": self.scheduler_state_dict
|
"scheduler": self.scheduler_state_dict
|
||||||
}
|
}
|
||||||
|
with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
|
||||||
save(state_dict, checkpoint_id=str(save_path))
|
torch.save(state_dict, f)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -62,9 +59,9 @@ class Checkpoint:
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
) -> "Checkpoint":
|
) -> "Checkpoint":
|
||||||
|
|
||||||
save_path = str(Path(save_dir))
|
|
||||||
rank = get_rank()
|
rank = get_rank()
|
||||||
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = {}
|
meta = {}
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||||
|
|
@ -75,11 +72,8 @@ class Checkpoint:
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
dist.broadcast_object_list(meta_list, src=0)
|
||||||
meta = meta_list[0]
|
meta = meta_list[0]
|
||||||
|
|
||||||
state_dict = {
|
with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
|
||||||
"optimizer": {},
|
state_dict = torch.load(f)
|
||||||
"scheduler": {}
|
|
||||||
}
|
|
||||||
load(state_dict, checkpoint_id=save_path, no_dist=True)
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
optimizer_state_dict=state_dict["optimizer"],
|
optimizer_state_dict=state_dict["optimizer"],
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
from khaosz.trainer.checkpoint import Checkpoint
|
|
||||||
from khaosz.trainer.strategy import StrategyFactory
|
from khaosz.trainer.strategy import StrategyFactory
|
||||||
from khaosz.trainer.schedule import SchedulerFactory
|
from khaosz.trainer.schedule import SchedulerFactory
|
||||||
|
|
||||||
|
|
@ -16,9 +15,6 @@ __all__ = [
|
||||||
# trainer
|
# trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
|
||||||
# checkpoint
|
|
||||||
"Checkpoint",
|
|
||||||
|
|
||||||
# factory
|
# factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from khaosz.trainer.metric_util import (
|
||||||
grad_std,
|
grad_std,
|
||||||
grad_nan_num
|
grad_nan_num
|
||||||
)
|
)
|
||||||
from khaosz.trainer.checkpoint import Checkpoint
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from khaosz.trainer.train_context import TrainContext
|
from khaosz.trainer.train_context import TrainContext
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from khaosz.data import ResumableDistributedSampler
|
from khaosz.data import ResumableDistributedSampler
|
||||||
from khaosz.trainer.checkpoint import Checkpoint
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
|
from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from khaosz.trainer.train_callback import (
|
||||||
SchedulerCallback
|
SchedulerCallback
|
||||||
)
|
)
|
||||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
from khaosz.trainer.checkpoint import Checkpoint
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
from khaosz.parallel.setup import spawn_parallel_fn
|
from khaosz.parallel.setup import spawn_parallel_fn
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
|
import torch
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
|
|
||||||
|
def test_single_process():
    """Round-trip a ``Checkpoint`` through ``save`` and ``load`` in one process.

    Trains a small linear model for a few steps so the optimizer and scheduler
    have state, saves a checkpoint (with the metric plot enabled) into a
    temporary directory, loads it back, and checks the restored epoch,
    iteration, loss metrics, optimizer state-dict keys, and that at least one
    PNG file was written.
    """
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # Take real optimisation steps so the optimizer state dict is populated.
    # NOTE(review): the original nesting of scheduler.step() could not be
    # recovered from the diff; once per epoch is assumed - confirm.
    for _ in range(3):
        for _ in range(10):
            x = torch.randn(32, 10)
            loss = model(x).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    checkpoint = Checkpoint(
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        epoch=3,
        iteration=30,
        metrics={
            "loss": [0.5, 0.4, 0.3, 0.2, 0.1],
            "accuracy": [0.6, 0.7, 0.8, 0.85, 0.9],
        },
    )

    # Everything that touches the files must stay inside the context manager:
    # the directory is removed as soon as the block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint.save(tmpdir, save_metric_plot=True)

        loaded_checkpoint = Checkpoint.load(tmpdir)

        assert loaded_checkpoint.epoch == 3
        assert loaded_checkpoint.iteration == 30
        assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]

        assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
        assert 'state' in loaded_checkpoint.optimizer_state_dict

        png_files = list(Path(tmpdir).glob("*.png"))
        assert png_files
|
||||||
|
|
||||||
|
def test_multi_process():
    """Placeholder for a multi-process checkpoint test.

    Intentionally empty: it performs no checks and always passes.
    """
    # TODO: exercise Checkpoint save/load across several ranks.
|
||||||
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from khaosz.config import *
|
from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
|
|
||||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
"""Simulate early stopping behavior"""
|
"""Simulate early stopping behavior"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue