feat(trainer): 重构检查点系统支持分布式训练
This commit is contained in:
parent
d21682f97a
commit
3d8047fa1b
|
|
@ -1,69 +1,104 @@
|
|||
import os
|
||||
import pickle as pkl
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Dict, Optional
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.checkpoint import save, load
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
return dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
|
||||
class Checkpoint:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer_state: Dict[str, Tensor],
|
||||
scheduler_state: Dict[str, Tensor],
|
||||
optimizer_state_dict: Dict[str, Any],
|
||||
scheduler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
metrics: Optional[Dict[str, list]] = None,
|
||||
):
|
||||
self.optimizer_state = optimizer_state
|
||||
self.scheduler_state = scheduler_state
|
||||
self.epoch, self.iteration = epoch, iteration
|
||||
self.metrics = metrics
|
||||
self.optimizer_state_dict = optimizer_state_dict
|
||||
self.scheduler_state_dict = scheduler_state_dict
|
||||
self.epoch = epoch
|
||||
self.iteration = iteration
|
||||
self.metrics = metrics or {}
|
||||
|
||||
def save(self, save_dir: str, save_metric_plot=True) -> None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
def save(
|
||||
self,
|
||||
save_dir: str,
|
||||
save_metric_plot: bool = True,
|
||||
) -> None:
|
||||
|
||||
train_state = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"metrics": self.metrics,
|
||||
"optimizer_state": self.optimizer_state,
|
||||
"scheduler_state": self.scheduler_state,
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"metrics": self.metrics,
|
||||
}
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
if save_metric_plot and self.metrics:
|
||||
self._plot_metrics(str(save_path))
|
||||
|
||||
state_dict = {
|
||||
"optimizer": self.optimizer_state_dict,
|
||||
"scheduler": self.scheduler_state_dict
|
||||
}
|
||||
|
||||
with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
|
||||
pkl.dump(train_state, f)
|
||||
|
||||
if save_metric_plot and self.metrics:
|
||||
self._plot_metrics(save_dir)
|
||||
save(state_dict, checkpoint_id=str(save_path))
|
||||
|
||||
@classmethod
|
||||
def load(cls, save_dir: str) -> "Checkpoint":
|
||||
checkpoint_path = os.path.join(save_dir, "train_state.pkl")
|
||||
def load(
|
||||
cls,
|
||||
save_dir: str,
|
||||
) -> "Checkpoint":
|
||||
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
|
||||
save_path = str(Path(save_dir))
|
||||
rank = get_rank()
|
||||
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
train_state = pkl.load(f)
|
||||
meta = {}
|
||||
if rank == 0:
|
||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
if dist.is_initialized():
|
||||
meta_list = [meta]
|
||||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
state_dict = {
|
||||
"optimizer": {},
|
||||
"scheduler": {}
|
||||
}
|
||||
load(state_dict, checkpoint_id=save_path, no_dist=True)
|
||||
|
||||
return cls(
|
||||
optimizer_state=train_state["optimizer_state"],
|
||||
scheduler_state=train_state["scheduler_state"],
|
||||
epoch=train_state["epoch"],
|
||||
iteration=train_state["iteration"],
|
||||
metrics=train_state["metrics"]
|
||||
optimizer_state_dict=state_dict["optimizer"],
|
||||
scheduler_state_dict=state_dict["scheduler"],
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
metrics=meta.get("metrics", {}),
|
||||
)
|
||||
|
||||
def _plot_metrics(self, save_dir: str):
|
||||
for metric_name, metric_value in self.metrics.items():
|
||||
for name, values in self.metrics.items():
|
||||
if not values:
|
||||
continue
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(metric_value, label=metric_name)
|
||||
plt.xlabel('Step')
|
||||
plt.ylabel('Value')
|
||||
plt.plot(values, label=name)
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Value")
|
||||
plt.title(f"Training Metric: {name}")
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
save_path = os.path.join(save_dir, f"{metric_name}.png")
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=150, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
from tqdm import tqdm
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from typing import List, Optional, Protocol, TYPE_CHECKING
|
||||
from typing import List, Literal, Optional, Protocol, TYPE_CHECKING
|
||||
|
||||
from khaosz.parallel import only_on_rank
|
||||
from khaosz.trainer.metric_util import (
|
||||
|
|
@ -104,15 +104,15 @@ class CheckpointCallback(TrainCallback):
|
|||
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_checkpoint(self, context: 'TrainContext'):
|
||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||
context.checkpoint = Checkpoint(
|
||||
context.optimizer.state_dict(),
|
||||
context.scheduler.state_dict(),
|
||||
context.epoch,
|
||||
context.iteration
|
||||
optimizer_state_dict=context.optimizer.state_dict(),
|
||||
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
|
|
|
|||
|
|
@ -41,18 +41,31 @@ class TrainContextBuilder:
|
|||
rank=get_rank(),
|
||||
)
|
||||
|
||||
device = get_current_device()
|
||||
self._context.model = self._context.model.to(device=device)
|
||||
|
||||
if self.config.nprocs > 1:
|
||||
|
||||
fn = self.config.parallel_wrapper
|
||||
optimizer_fn = self.config.optimizer_factory
|
||||
scheduler_fn = self.config.scheduler_factory
|
||||
|
||||
self._context.model = fn(self._context.model)
|
||||
self._context.optimizer = optimizer_fn(self._context.model.parameters())
|
||||
self._context.scheduler = scheduler_fn(self._context.optimizer)
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
optimizer_state=self.config.optimizer.state_dict(),
|
||||
scheduler_state=self.config.scheduler.state_dict(),
|
||||
optimizer_state_dict=self.config.optimizer.state_dict(),
|
||||
scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None,
|
||||
)
|
||||
else:
|
||||
# resume from the assigned checkpoint or assigned iteration
|
||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||
self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
|
||||
self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
|
||||
self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
|
||||
self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
|
||||
|
||||
self._context.checkpoint = checkpoint
|
||||
return self
|
||||
|
|
@ -88,21 +101,6 @@ class TrainContextBuilder:
|
|||
)
|
||||
return self
|
||||
|
||||
def with_parallel(self) -> Self:
|
||||
device = get_current_device()
|
||||
self._context.model = self._context.model.to(device=device)
|
||||
|
||||
if self.config.nprocs > 1:
|
||||
|
||||
fn = self.config.parallel_wrapper
|
||||
optimizer_fn = self.config.optimizer_factory
|
||||
scheduler_fn = self.config.scheduler_factory
|
||||
|
||||
self._context.model = fn(self._context.model)
|
||||
self._context.optimizer = optimizer_fn(self._context.model.parameters())
|
||||
self._context.scheduler = scheduler_fn(self._context.optimizer)
|
||||
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
return self._context
|
||||
|
|
@ -38,7 +38,6 @@ class Trainer:
|
|||
.with_checkpoint(checkpoint)
|
||||
.with_dataloader()
|
||||
.with_strategy()
|
||||
.with_parallel()
|
||||
.build())
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
|
|
|
|||
Loading…
Reference in New Issue