feat(trainer): 重构检查点系统支持分布式训练
This commit is contained in:
parent
d21682f97a
commit
3d8047fa1b
|
|
@ -1,69 +1,104 @@
|
||||||
import os
|
import os
|
||||||
import pickle as pkl
|
import json
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
|
||||||
from torch import Tensor
|
import torch.distributed as dist
|
||||||
from typing import Dict, Optional
|
from torch.distributed.checkpoint import save, load
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank() -> int:
|
||||||
|
return dist.get_rank() if dist.is_initialized() else 0
|
||||||
|
|
||||||
|
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
optimizer_state: Dict[str, Tensor],
|
optimizer_state_dict: Dict[str, Any],
|
||||||
scheduler_state: Dict[str, Tensor],
|
scheduler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
epoch: int = 0,
|
epoch: int = 0,
|
||||||
iteration: int = 0,
|
iteration: int = 0,
|
||||||
metrics: Optional[Dict[str, list]] = None,
|
metrics: Optional[Dict[str, list]] = None,
|
||||||
):
|
):
|
||||||
self.optimizer_state = optimizer_state
|
self.optimizer_state_dict = optimizer_state_dict
|
||||||
self.scheduler_state = scheduler_state
|
self.scheduler_state_dict = scheduler_state_dict
|
||||||
self.epoch, self.iteration = epoch, iteration
|
self.epoch = epoch
|
||||||
self.metrics = metrics
|
self.iteration = iteration
|
||||||
|
self.metrics = metrics or {}
|
||||||
|
|
||||||
def save(self, save_dir: str, save_metric_plot=True) -> None:
|
def save(
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
self,
|
||||||
|
save_dir: str,
|
||||||
|
save_metric_plot: bool = True,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
train_state = {
|
save_path = Path(save_dir)
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
meta = {
|
||||||
"epoch": self.epoch,
|
"epoch": self.epoch,
|
||||||
"iteration": self.iteration,
|
"iteration": self.iteration,
|
||||||
"metrics": self.metrics,
|
"metrics": self.metrics,
|
||||||
"optimizer_state": self.optimizer_state,
|
|
||||||
"scheduler_state": self.scheduler_state,
|
|
||||||
}
|
}
|
||||||
|
with open(save_path / "meta.json", "w") as f:
|
||||||
with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
|
json.dump(meta, f, indent=2)
|
||||||
pkl.dump(train_state, f)
|
|
||||||
|
|
||||||
if save_metric_plot and self.metrics:
|
if save_metric_plot and self.metrics:
|
||||||
self._plot_metrics(save_dir)
|
self._plot_metrics(str(save_path))
|
||||||
|
|
||||||
|
state_dict = {
|
||||||
|
"optimizer": self.optimizer_state_dict,
|
||||||
|
"scheduler": self.scheduler_state_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
save(state_dict, checkpoint_id=str(save_path))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, save_dir: str) -> "Checkpoint":
|
def load(
|
||||||
checkpoint_path = os.path.join(save_dir, "train_state.pkl")
|
cls,
|
||||||
|
save_dir: str,
|
||||||
|
) -> "Checkpoint":
|
||||||
|
|
||||||
if not os.path.exists(checkpoint_path):
|
save_path = str(Path(save_dir))
|
||||||
raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
|
rank = get_rank()
|
||||||
|
|
||||||
with open(checkpoint_path, "rb") as f:
|
meta = {}
|
||||||
train_state = pkl.load(f)
|
if rank == 0:
|
||||||
|
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
meta_list = [meta]
|
||||||
|
dist.broadcast_object_list(meta_list, src=0)
|
||||||
|
meta = meta_list[0]
|
||||||
|
|
||||||
|
state_dict = {
|
||||||
|
"optimizer": {},
|
||||||
|
"scheduler": {}
|
||||||
|
}
|
||||||
|
load(state_dict, checkpoint_id=save_path, no_dist=True)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
optimizer_state=train_state["optimizer_state"],
|
optimizer_state_dict=state_dict["optimizer"],
|
||||||
scheduler_state=train_state["scheduler_state"],
|
scheduler_state_dict=state_dict["scheduler"],
|
||||||
epoch=train_state["epoch"],
|
epoch=meta["epoch"],
|
||||||
iteration=train_state["iteration"],
|
iteration=meta["iteration"],
|
||||||
metrics=train_state["metrics"]
|
metrics=meta.get("metrics", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _plot_metrics(self, save_dir: str):
|
def _plot_metrics(self, save_dir: str):
|
||||||
for metric_name, metric_value in self.metrics.items():
|
for name, values in self.metrics.items():
|
||||||
|
if not values:
|
||||||
|
continue
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(metric_value, label=metric_name)
|
plt.plot(values, label=name)
|
||||||
plt.xlabel('Step')
|
plt.xlabel("Step")
|
||||||
plt.ylabel('Value')
|
plt.ylabel("Value")
|
||||||
|
plt.title(f"Training Metric: {name}")
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.grid(True, alpha=0.3)
|
plt.grid(True, alpha=0.3)
|
||||||
|
plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=150, bbox_inches="tight")
|
||||||
save_path = os.path.join(save_dir, f"{metric_name}.png")
|
|
||||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from typing import List, Optional, Protocol, TYPE_CHECKING
|
from typing import List, Literal, Optional, Protocol, TYPE_CHECKING
|
||||||
|
|
||||||
from khaosz.parallel import only_on_rank
|
from khaosz.parallel import only_on_rank
|
||||||
from khaosz.trainer.metric_util import (
|
from khaosz.trainer.metric_util import (
|
||||||
|
|
@ -104,15 +104,15 @@ class CheckpointCallback(TrainCallback):
|
||||||
|
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
|
||||||
def _save_checkpoint(self, context: 'TrainContext'):
|
def _save_checkpoint(self, context: 'TrainContext'):
|
||||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
context.optimizer.state_dict(),
|
optimizer_state_dict=context.optimizer.state_dict(),
|
||||||
context.scheduler.state_dict(),
|
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
|
||||||
context.epoch,
|
epoch=context.epoch,
|
||||||
context.iteration
|
iteration=context.iteration
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,18 +41,31 @@ class TrainContextBuilder:
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
device = get_current_device()
|
||||||
|
self._context.model = self._context.model.to(device=device)
|
||||||
|
|
||||||
|
if self.config.nprocs > 1:
|
||||||
|
|
||||||
|
fn = self.config.parallel_wrapper
|
||||||
|
optimizer_fn = self.config.optimizer_factory
|
||||||
|
scheduler_fn = self.config.scheduler_factory
|
||||||
|
|
||||||
|
self._context.model = fn(self._context.model)
|
||||||
|
self._context.optimizer = optimizer_fn(self._context.model.parameters())
|
||||||
|
self._context.scheduler = scheduler_fn(self._context.optimizer)
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
optimizer_state=self.config.optimizer.state_dict(),
|
optimizer_state_dict=self.config.optimizer.state_dict(),
|
||||||
scheduler_state=self.config.scheduler.state_dict(),
|
scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# resume from the assigned checkpoint or assigned iteration
|
# resume from the assigned checkpoint or assigned iteration
|
||||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||||
self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
|
self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
|
||||||
self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
|
self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
@ -88,21 +101,6 @@ class TrainContextBuilder:
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_parallel(self) -> Self:
|
|
||||||
device = get_current_device()
|
|
||||||
self._context.model = self._context.model.to(device=device)
|
|
||||||
|
|
||||||
if self.config.nprocs > 1:
|
|
||||||
|
|
||||||
fn = self.config.parallel_wrapper
|
|
||||||
optimizer_fn = self.config.optimizer_factory
|
|
||||||
scheduler_fn = self.config.scheduler_factory
|
|
||||||
|
|
||||||
self._context.model = fn(self._context.model)
|
|
||||||
self._context.optimizer = optimizer_fn(self._context.model.parameters())
|
|
||||||
self._context.scheduler = scheduler_fn(self._context.optimizer)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
return self._context
|
return self._context
|
||||||
|
|
@ -38,7 +38,6 @@ class Trainer:
|
||||||
.with_checkpoint(checkpoint)
|
.with_checkpoint(checkpoint)
|
||||||
.with_dataloader()
|
.with_dataloader()
|
||||||
.with_strategy()
|
.with_strategy()
|
||||||
.with_parallel()
|
|
||||||
.build())
|
.build())
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue