refactor(trainer): 优化trainer 结构
This commit is contained in:
parent
82e65ccc21
commit
c98b175cd5
|
|
@ -4,7 +4,6 @@ __author__ = "ViperEkura"
|
||||||
from khaosz.api import Khaosz
|
from khaosz.api import Khaosz
|
||||||
from khaosz.config import (
|
from khaosz.config import (
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
ParameterLoader,
|
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
@ -42,7 +41,6 @@ __all__ = [
|
||||||
"PriorityTextSplitter",
|
"PriorityTextSplitter",
|
||||||
|
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"ParameterLoader",
|
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,13 @@ from khaosz.inference.generator import (
|
||||||
RetrievalGenerator,
|
RetrievalGenerator,
|
||||||
EmbeddingEncoder
|
EmbeddingEncoder
|
||||||
)
|
)
|
||||||
from khaosz.config.param_config import ParameterLoader
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
|
||||||
|
|
||||||
class Khaosz:
|
class Khaosz:
|
||||||
def __init__(self, model_dir: str):
|
def __init__(self, model_dir: str):
|
||||||
self.parameter = ParameterLoader.load(model_dir)
|
self.parameter = ModelParameter()
|
||||||
|
self.parameter.load(model_dir)
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs):
|
||||||
self.parameter.to(*args, **kwargs)
|
self.parameter.to(*args, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from khaosz.config.model_config import ModelConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
||||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
@ -7,8 +7,6 @@ from khaosz.config.train_config import TrainConfig
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModelIO",
|
"BaseModelIO",
|
||||||
"ModelParameter",
|
"ModelParameter",
|
||||||
"Checkpoint",
|
|
||||||
"ParameterLoader",
|
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,8 @@
|
||||||
import pickle as pkl
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import safetensors.torch as st
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Self, Union
|
from typing import Optional, Self, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
|
|
@ -63,7 +60,7 @@ class BaseModelIO:
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to(self, *args, **kwargs) -> Self:
|
def to(self, *args, **kwargs) -> "BaseModelIO":
|
||||||
"""Move model to device."""
|
"""Move model to device."""
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
self.model.to(*args, **kwargs)
|
self.model.to(*args, **kwargs)
|
||||||
|
|
@ -77,142 +74,6 @@ class ModelParameter(BaseModelIO):
|
||||||
def save(self, save_dir: Union[str, Path]):
|
def save(self, save_dir: Union[str, Path]):
|
||||||
self.save_components(save_dir)
|
self.save_components(save_dir)
|
||||||
|
|
||||||
def load(self, load_dir: Union[str, Path]) -> Self:
|
def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||||
return self.load_components(load_dir)
|
return self.load_components(load_dir)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Checkpoint(BaseModelIO):
|
|
||||||
"""Extended model parameters with training state."""
|
|
||||||
|
|
||||||
optimizer_state: Dict[str, Any] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Optimizer state."}
|
|
||||||
)
|
|
||||||
scheduler_state: Dict[str, Any] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Sampler state."}
|
|
||||||
)
|
|
||||||
loss_list: List[float] = field(
|
|
||||||
default_factory=list,
|
|
||||||
metadata={"help": "List of training losses."}
|
|
||||||
)
|
|
||||||
epoch: int = field(
|
|
||||||
default=0,
|
|
||||||
metadata={"help": "Current epoch."}
|
|
||||||
)
|
|
||||||
batch_iter: int = field(
|
|
||||||
default=0,
|
|
||||||
metadata={"help": "Current iteration."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
|
||||||
dir_path = Path(directory)
|
|
||||||
return {
|
|
||||||
"loss_plot": dir_path / "loss_plot.png",
|
|
||||||
"training_state": dir_path / "training_state.pkl"
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"optimizer_state": self.optimizer_state,
|
|
||||||
"scheduler_state": self.scheduler_state,
|
|
||||||
"epoch": self.epoch,
|
|
||||||
"batch_iter": self.batch_iter,
|
|
||||||
"loss_list": self.loss_list,
|
|
||||||
}
|
|
||||||
|
|
||||||
def from_dict(self, data: Dict[str, Any]) -> Self:
|
|
||||||
self.optimizer_state = data["optimizer_state"]
|
|
||||||
self.scheduler_state = data["scheduler_state"]
|
|
||||||
self.epoch = data["epoch"]
|
|
||||||
self.batch_iter = data["batch_iter"]
|
|
||||||
self.loss_list = data["loss_list"]
|
|
||||||
|
|
||||||
def save_training_state(self, save_dir: Union[str, Path]):
|
|
||||||
paths = self._get_training_paths(save_dir)
|
|
||||||
|
|
||||||
# Save loss plot
|
|
||||||
self._plot_loss(str(paths["loss_plot"]))
|
|
||||||
|
|
||||||
# Save training state
|
|
||||||
with open(str(paths["training_state"]), "wb") as f:
|
|
||||||
pkl.dump(self.to_dict(), f)
|
|
||||||
|
|
||||||
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
|
|
||||||
paths = self._get_training_paths(load_dir)
|
|
||||||
|
|
||||||
# Load training state
|
|
||||||
with open(str(paths["training_state"]), "rb") as f:
|
|
||||||
train_state = pkl.load(f)
|
|
||||||
|
|
||||||
self.from_dict(train_state)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def _plot_loss(self, save_path: str):
|
|
||||||
"""Plot and save loss curve."""
|
|
||||||
if not self.loss_list:
|
|
||||||
return
|
|
||||||
|
|
||||||
batch_iter = len(self.loss_list)
|
|
||||||
|
|
||||||
plt.figure(figsize=(10, 6))
|
|
||||||
plt.plot(self.loss_list)
|
|
||||||
plt.title(f"Training Loss - Iteration {batch_iter}")
|
|
||||||
plt.xlabel("Batch")
|
|
||||||
plt.ylabel("Loss")
|
|
||||||
plt.grid(True)
|
|
||||||
plt.savefig(save_path, dpi=30, bbox_inches="tight")
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
def save(self, save_dir: Union[str, Path]):
|
|
||||||
"""Save complete checkpoint."""
|
|
||||||
self.save_components(save_dir)
|
|
||||||
self.save_training_state(save_dir)
|
|
||||||
|
|
||||||
def load(self, load_dir: Union[str, Path]) -> Self:
|
|
||||||
"""Load complete checkpoint."""
|
|
||||||
self.load_components(load_dir)
|
|
||||||
self.load_training_state(load_dir)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class ParameterLoader:
|
|
||||||
"""Factory class for loading model parameters or checkpoints."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
|
|
||||||
"""Load either ModelParameter or Checkpoint based on directory contents."""
|
|
||||||
load_dir = Path(load_dir)
|
|
||||||
|
|
||||||
# Check for training-specific files
|
|
||||||
loss_file = load_dir / "loss.pkl"
|
|
||||||
has_training_data = loss_file.exists()
|
|
||||||
|
|
||||||
# Create appropriate instance
|
|
||||||
if has_training_data:
|
|
||||||
checkpoint = Checkpoint()
|
|
||||||
checkpoint.load(str(load_dir))
|
|
||||||
return checkpoint
|
|
||||||
else:
|
|
||||||
params = ModelParameter()
|
|
||||||
params.load(str(load_dir))
|
|
||||||
return params
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_checkpoint(
|
|
||||||
model: nn.Module,
|
|
||||||
tokenizer: BpeTokenizer,
|
|
||||||
config: ModelConfig,
|
|
||||||
loss_list: Optional[list[float]] = None,
|
|
||||||
optimizer: Optional[optim.Optimizer] = None,
|
|
||||||
) -> Checkpoint:
|
|
||||||
"""Convenience method to create a training checkpoint."""
|
|
||||||
return Checkpoint(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
config=config,
|
|
||||||
loss_list=loss_list or [],
|
|
||||||
optimizer_state=optimizer
|
|
||||||
)
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Literal, Dict
|
from typing import Any, Dict
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
@ -39,7 +39,10 @@ class CosineScheduleConfig(ScheduleConfig):
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Total training steps for cosine schedule."}
|
metadata={"help": "Total training steps for cosine schedule."}
|
||||||
)
|
)
|
||||||
schedule_type: Literal["cosine"] = "cosine"
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.schedule_type = "cosine"
|
||||||
|
self.validate()
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
if self.total_steps is None:
|
if self.total_steps is None:
|
||||||
|
|
@ -68,7 +71,10 @@ class SGDRScheduleConfig(ScheduleConfig):
|
||||||
default=2,
|
default=2,
|
||||||
metadata={"help": "Multiplier for cycle length growth."}
|
metadata={"help": "Multiplier for cycle length growth."}
|
||||||
)
|
)
|
||||||
schedule_type: Literal["sgdr"] = "sgdr"
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.schedule_type = "sgdr"
|
||||||
|
self.validate()
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,20 @@
|
||||||
from dataclasses import dataclass, field
|
from torch import nn
|
||||||
from typing import Optional, TYPE_CHECKING
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from dataclasses import dataclass, field
|
||||||
from khaosz.trainer.strategy import BaseStrategy
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
strategy: "BaseStrategy" = field(
|
# basic setting
|
||||||
|
model: nn.Module = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Model for training."}
|
||||||
|
)
|
||||||
|
strategy: str = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Training strategy."}
|
metadata={"help": "Training strategy."}
|
||||||
)
|
)
|
||||||
|
|
@ -21,9 +26,9 @@ class TrainConfig:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Optimizer for training."}
|
metadata={"help": "Optimizer for training."}
|
||||||
)
|
)
|
||||||
checkpoint_dir: str = field(
|
scheduler: LRScheduler = field(
|
||||||
default="./checkpoint",
|
default=None,
|
||||||
metadata={"help": "Checkpoint directory."}
|
metadata={"help": "Scheduler for training."}
|
||||||
)
|
)
|
||||||
n_epoch: int = field(
|
n_epoch: int = field(
|
||||||
default=1,
|
default=1,
|
||||||
|
|
@ -33,6 +38,16 @@ class TrainConfig:
|
||||||
default=4,
|
default=4,
|
||||||
metadata={"help": "Batch size for training."}
|
metadata={"help": "Batch size for training."}
|
||||||
)
|
)
|
||||||
|
accumulation_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of iterations between steps."}
|
||||||
|
)
|
||||||
|
max_grad_norm: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "Maximum gradient norm."}
|
||||||
|
)
|
||||||
|
|
||||||
|
# checkpoint setting
|
||||||
start_epoch: int = field(
|
start_epoch: int = field(
|
||||||
default=0,
|
default=0,
|
||||||
metadata={"help": "Start epoch for training."}
|
metadata={"help": "Start epoch for training."}
|
||||||
|
|
@ -41,18 +56,14 @@ class TrainConfig:
|
||||||
default=0,
|
default=0,
|
||||||
metadata={"help": "Start batch iteration for training."}
|
metadata={"help": "Start batch iteration for training."}
|
||||||
)
|
)
|
||||||
|
checkpoint_dir: str = field(
|
||||||
|
default="./checkpoint",
|
||||||
|
metadata={"help": "Checkpoint directory."}
|
||||||
|
)
|
||||||
checkpoint_interval: int = field(
|
checkpoint_interval: int = field(
|
||||||
default=5000,
|
default=5000,
|
||||||
metadata={"help": "Number of iterations between checkpoints."}
|
metadata={"help": "Number of iterations between checkpoints."}
|
||||||
)
|
)
|
||||||
accumulation_steps: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of iterations between steps."}
|
|
||||||
)
|
|
||||||
max_grad_norm: float = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "Maximum gradient norm."}
|
|
||||||
)
|
|
||||||
|
|
||||||
# dataloader setting
|
# dataloader setting
|
||||||
random_seed: int = field(
|
random_seed: int = field(
|
||||||
|
|
@ -77,3 +88,9 @@ class TrainConfig:
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "Number of processes for distributed training."}
|
metadata={"help": "Number of processes for distributed training."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# others
|
||||||
|
kwargs: dict = field(
|
||||||
|
default_factory=dict,
|
||||||
|
metadata={"help": "Other arguments."}
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
import os
|
||||||
|
import pickle as pkl
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Checkpoint:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizer_state: Optimizer,
|
||||||
|
scheduler_state: LRScheduler,
|
||||||
|
epoch: int = 0,
|
||||||
|
iteration: int = 0,
|
||||||
|
metrics: Optional[Dict[str, list]] = None,
|
||||||
|
):
|
||||||
|
self.optimizer_state = optimizer_state
|
||||||
|
self.scheduler_state = scheduler_state
|
||||||
|
self.epoch, self.iteration = epoch, iteration
|
||||||
|
self.metrics = metrics
|
||||||
|
|
||||||
|
def save(self, save_dir: str, save_metric_plot=True) -> None:
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
train_state = {
|
||||||
|
"epoch": self.epoch,
|
||||||
|
"iteration": self.iteration,
|
||||||
|
"metrics": self.metrics,
|
||||||
|
"optimizer_state": self.optimizer_state,
|
||||||
|
"scheduler_state": self.scheduler_state,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
|
||||||
|
pkl.dump(train_state, f)
|
||||||
|
|
||||||
|
if save_metric_plot and self.metrics:
|
||||||
|
self._plot_metrics()
|
||||||
|
|
||||||
|
def load(self, save_dir: str) -> "Checkpoint":
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
|
||||||
|
|
||||||
|
with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
|
||||||
|
train_state = pkl.load(f)
|
||||||
|
self.epoch = train_state["epoch"]
|
||||||
|
self.iteration = train_state["iteration"]
|
||||||
|
self.metrics = train_state["metrics"]
|
||||||
|
self.optimizer_state = train_state["optimizer_state"]
|
||||||
|
self.scheduler_state = train_state["scheduler_state"]
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _plot_metrics(self):
|
||||||
|
for metric_name, metric_value in self.metrics.items():
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(metric_value, label=metric_name)
|
||||||
|
plt.xlabel('Step')
|
||||||
|
plt.ylabel('Value')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
@ -151,7 +151,7 @@ class SchedulerFactory:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_scheduler(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
|
def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
|
||||||
kwargs = scedule_config.get_kwargs()
|
kwargs = scedule_config.get_kwargs()
|
||||||
schedule_type = kwargs.pop("schedule_type")
|
schedule_type = kwargs.pop("schedule_type")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Tuple, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -41,7 +41,7 @@ class BaseStrategy(ABC):
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
return self.compute_loss(batch)
|
return self.compute_loss(batch)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -94,7 +94,7 @@ class DpoStrategy(BaseStrategy):
|
||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
|
|
||||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
||||||
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||||
|
|
@ -115,41 +115,6 @@ class DpoStrategy(BaseStrategy):
|
||||||
return dpo_loss
|
return dpo_loss
|
||||||
|
|
||||||
|
|
||||||
class PpoStrategy(BaseStrategy):
|
|
||||||
def __init__(self, model, pad_token_id, epsilon):
|
|
||||||
super().__init__(model)
|
|
||||||
ref_model = copy.deepcopy(self.model)
|
|
||||||
ref_model.requires_grad_(False)
|
|
||||||
ref_model.eval()
|
|
||||||
|
|
||||||
self.ref_model = ref_model
|
|
||||||
self.pad_token_id = pad_token_id
|
|
||||||
self.epsilon = epsilon
|
|
||||||
|
|
||||||
def ppo_clip_loss_masked(
|
|
||||||
self,
|
|
||||||
log_probs: Tensor,
|
|
||||||
old_log_probs: Tensor,
|
|
||||||
advantages: Tensor,
|
|
||||||
values: Tensor,
|
|
||||||
returns: Tensor,
|
|
||||||
mask: Tensor,
|
|
||||||
clip_eps: float=0.2,
|
|
||||||
):
|
|
||||||
ratio = torch.exp(log_probs - old_log_probs)
|
|
||||||
surr1 = ratio * advantages
|
|
||||||
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
|
|
||||||
policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
|
|
||||||
|
|
||||||
value_loss = F.mse_loss(values.masked_select(mask),
|
|
||||||
returns.masked_select(mask))
|
|
||||||
|
|
||||||
entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
|
|
||||||
entropy_loss = -entropy
|
|
||||||
return policy_loss, value_loss, entropy_loss
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class StrategyFactory:
|
class StrategyFactory:
|
||||||
|
|
||||||
def load(model, train_type, device, **kwargs):
|
def load(model, train_type, device, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,9 @@ import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from typing import List, Optional, Protocol, TYPE_CHECKING
|
from typing import List, Optional, Protocol, TYPE_CHECKING
|
||||||
|
|
||||||
from khaosz.config import ScheduleConfig
|
|
||||||
from khaosz.trainer.metric_util import (
|
from khaosz.trainer.metric_util import (
|
||||||
grad_max,
|
grad_max,
|
||||||
grad_min,
|
grad_min,
|
||||||
|
|
@ -17,9 +16,9 @@ from khaosz.trainer.metric_util import (
|
||||||
grad_std,
|
grad_std,
|
||||||
grad_nan_num
|
grad_nan_num
|
||||||
)
|
)
|
||||||
|
from khaosz.trainer.checkpoint import Checkpoint
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from khaosz.trainer.trainer import Trainer
|
|
||||||
from khaosz.trainer.train_context import TrainContext
|
from khaosz.trainer.train_context import TrainContext
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,31 +27,31 @@ class TrainCallback(Protocol):
|
||||||
Callback interface for trainer.
|
Callback interface for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_train_begin(self, context: 'TrainContext'):
|
||||||
""" Called at the beginning of training. """
|
""" Called at the beginning of training. """
|
||||||
|
|
||||||
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_train_end(self, context: 'TrainContext'):
|
||||||
""" Called at the end of training. """
|
""" Called at the end of training. """
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_epoch_begin(self, context: 'TrainContext'):
|
||||||
""" Called at the beginning of each epoch. """
|
""" Called at the beginning of each epoch. """
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_epoch_end(self, context: 'TrainContext'):
|
||||||
""" Called at the end of each epoch. """
|
""" Called at the end of each epoch. """
|
||||||
|
|
||||||
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_step_begin(self, context: 'TrainContext'):
|
||||||
""" Called at the beginning of each step. """
|
""" Called at the beginning of each step. """
|
||||||
|
|
||||||
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_step_end(self, context: 'TrainContext'):
|
||||||
""" Called at the end of each step."""
|
""" Called at the end of each step."""
|
||||||
|
|
||||||
def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_batch_begin(self, context: 'TrainContext'):
|
||||||
""" Called at the beginning of each batch. """
|
""" Called at the beginning of each batch. """
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_batch_end(self, context: 'TrainContext'):
|
||||||
""" Called at the end of each batch. """
|
""" Called at the end of each batch. """
|
||||||
|
|
||||||
def on_error(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_error(self, context: 'TrainContext'):
|
||||||
""" Called when an error occurs during training. """
|
""" Called when an error occurs during training. """
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -63,29 +62,27 @@ class GradientClippingCallback(TrainCallback):
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_step_begin(self, context: 'TrainContext'):
|
||||||
_ = context
|
_ = context
|
||||||
clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerCallback(TrainCallback):
|
class SchedulerCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Scheduler callback for trainer.
|
Scheduler callback for trainer.
|
||||||
"""
|
"""
|
||||||
def __init__(self, schedule_config: ScheduleConfig):
|
def __init__(self, scheduler: LRScheduler):
|
||||||
self.schedule_config = schedule_config
|
self.scheduler: LRScheduler = scheduler
|
||||||
self.scheduler: Optional[LambdaLR] = None
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_train_begin(self, context: 'TrainContext'):
|
||||||
|
for group in context.optimizer.param_groups:
|
||||||
for group in trainer.train_config.optimizer.param_groups:
|
|
||||||
if "initial_lr" not in group:
|
if "initial_lr" not in group:
|
||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
|
|
||||||
self.scheduler = context.scheduler
|
self.scheduler = context.scheduler
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_batch_end(self, context: 'TrainContext'):
|
||||||
_ = trainer, context
|
_ = context
|
||||||
if self.scheduler:
|
if self.scheduler:
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
|
||||||
|
|
@ -94,54 +91,59 @@ class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
def __init__(self, checkpoint_interval: int):
|
def __init__(self, interval: int, save_dir: str):
|
||||||
self.checkpoint_interval = checkpoint_interval
|
self.interval = interval
|
||||||
|
self.save_dir = save_dir
|
||||||
|
self.checkpoint = None
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
|
def _save_checkpoint(self, context: 'TrainContext'):
|
||||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}")
|
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
|
||||||
context.checkpoint.optimizer_state = context.optimizer.state_dict()
|
self.checkpoint = Checkpoint(
|
||||||
context.checkpoint.scheduler_state = context.scheduler.state_dict()
|
context.optimizer.state_dict(),
|
||||||
context.checkpoint.epoch = context.epoch
|
context.scheduler.state_dict(),
|
||||||
context.checkpoint.batch_iter = context.batch_iter
|
context.epoch,
|
||||||
context.checkpoint.save(save_path)
|
context.iteration
|
||||||
self.last_ckpt_iter = context.batch_iter
|
)
|
||||||
|
self.checkpoint.save(save_path)
|
||||||
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_batch_end(self, context: 'TrainContext'):
|
||||||
context.checkpoint.loss_list.append(context.loss)
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
def on_train_end(self, context: 'TrainContext'):
|
||||||
self._save_checkpoint(trainer, context)
|
if context.iteration != self.last_ckpt_iter:
|
||||||
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_error(self, context: 'TrainContext'):
|
||||||
if context.batch_iter != self.last_ckpt_iter:
|
self._save_checkpoint(context)
|
||||||
self._save_checkpoint(trainer, context)
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self, num_epoch: int):
|
||||||
|
self.num_epoch = num_epoch
|
||||||
self.progress_bar: tqdm = None
|
self.progress_bar: tqdm = None
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_epoch_begin(self, context: 'TrainContext'):
|
||||||
self.progress_bar = tqdm(
|
self.progress_bar = tqdm(
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}",
|
desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True
|
dynamic_ncols=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_batch_end(self, context: 'TrainContext'):
|
||||||
_ = trainer
|
|
||||||
self.progress_bar.set_postfix({
|
self.progress_bar.set_postfix({
|
||||||
"loss": f"{context.loss:.4f}",
|
"loss": f"{context.loss:.4f}",
|
||||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
|
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
|
||||||
})
|
})
|
||||||
self.progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_epoch_end(self, context: 'TrainContext'):
|
||||||
_ = trainer, context
|
_ = context
|
||||||
if self.progress_bar:
|
if self.progress_bar:
|
||||||
self.progress_bar.close()
|
self.progress_bar.close()
|
||||||
|
|
||||||
|
|
@ -177,13 +179,13 @@ class StepMonitorCallback(TrainCallback):
|
||||||
|
|
||||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def _handle_info(self, trainer: 'Trainer', context: 'TrainContext'):
|
def _handle_info(self, context: 'TrainContext'):
|
||||||
""" Logs training information to console and file. """
|
""" Logs training information to console and file. """
|
||||||
|
|
||||||
log_data = {
|
log_data = {
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.batch_iter,
|
"iter": context.iteration,
|
||||||
"metrics": self.metrics,
|
"metrics": self.metrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -193,34 +195,34 @@ class StepMonitorCallback(TrainCallback):
|
||||||
elif metric == 'lr':
|
elif metric == 'lr':
|
||||||
log_data[metric] = context.optimizer.param_groups[-1]['lr']
|
log_data[metric] = context.optimizer.param_groups[-1]['lr']
|
||||||
elif metric == 'grad_norm':
|
elif metric == 'grad_norm':
|
||||||
log_data[metric] = grad_norm(trainer.parameter.model)
|
log_data[metric] = grad_norm(context.model)
|
||||||
elif metric == 'grad_std':
|
elif metric == 'grad_std':
|
||||||
log_data[metric] = grad_std(trainer.parameter.model)
|
log_data[metric] = grad_std(context.model)
|
||||||
elif metric == 'grad_max':
|
elif metric == 'grad_max':
|
||||||
log_data[metric] = grad_max(trainer.parameter.model)
|
log_data[metric] = grad_max(context.model)
|
||||||
elif metric == 'grad_min':
|
elif metric == 'grad_min':
|
||||||
log_data[metric] = grad_min(trainer.parameter.model)
|
log_data[metric] = grad_min(context.model)
|
||||||
elif metric == 'grad_mean':
|
elif metric == 'grad_mean':
|
||||||
log_data[metric] = grad_mean(trainer.parameter.model)
|
log_data[metric] = grad_mean(context.model)
|
||||||
elif metric == 'grad_nan_num':
|
elif metric == 'grad_nan_num':
|
||||||
log_data[metric] = grad_nan_num(trainer.parameter.model)
|
log_data[metric] = grad_nan_num(context.model)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid metric: {metric}")
|
raise ValueError(f"Invalid metric: {metric}")
|
||||||
|
|
||||||
return log_data
|
return log_data
|
||||||
|
|
||||||
def _handle_log(self, trainer: 'Trainer', context: 'TrainContext'):
|
def _handle_log(self, context: 'TrainContext'):
|
||||||
""" Logs training information to console and file. """
|
""" Logs training information to console and file. """
|
||||||
log_data = self._handle_info(trainer, context)
|
log_data = self._handle_info(context)
|
||||||
try:
|
try:
|
||||||
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json"
|
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.iteration}.json"
|
||||||
with open(log_file, 'a') as f:
|
with open(log_file, 'a') as f:
|
||||||
json.dump(log_data, f, indent=4)
|
json.dump(log_data, f, indent=4)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_step_end(self, context: 'TrainContext'):
|
||||||
if self.step_num % self.log_interval == 0:
|
if self.step_num % self.log_interval == 0:
|
||||||
self._handle_log(trainer, context)
|
self._handle_log(context)
|
||||||
|
|
||||||
self.step_num += 1
|
self.step_num += 1
|
||||||
|
|
@ -1,61 +1,60 @@
|
||||||
from dataclasses import dataclass, field, fields
|
import torch.nn as nn
|
||||||
from typing import Optional, Self, TYPE_CHECKING
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from khaosz.config import Checkpoint
|
|
||||||
from khaosz.data import ResumableDistributedSampler
|
|
||||||
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
|
||||||
from khaosz.parallel.utils import get_world_size, get_rank
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from khaosz.data import ResumableDistributedSampler
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.checkpoint import Checkpoint
|
||||||
|
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
||||||
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainContext:
|
class TrainContext:
|
||||||
|
model: nn.Module = field(default=None)
|
||||||
|
strategy: BaseStrategy = field(default=None)
|
||||||
dataloader: DataLoader = field(default=None)
|
dataloader: DataLoader = field(default=None)
|
||||||
optimizer: Optimizer = field(default=None)
|
optimizer: Optimizer = field(default=None)
|
||||||
scheduler: BaseScheduler = field(default=None)
|
scheduler: LRScheduler = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
batch_iter: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
|
|
||||||
wolrd_size: int = field(default=1)
|
wolrd_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
|
|
||||||
def asdict(self) -> dict:
|
|
||||||
return {field.name: getattr(self, field.name)
|
|
||||||
for field in fields(self)}
|
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, trainer: 'Trainer'):
|
def __init__(self, config: TrainConfig):
|
||||||
self.trainer = trainer
|
self.config = config
|
||||||
self._context: TrainContext = None
|
self._context: TrainContext = None
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
self._context = TrainContext()
|
self._context = TrainContext()
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
model=self.trainer.parameter.model,
|
optimizer_state=self.config.optimizer.state_dict(),
|
||||||
tokenizer=self.trainer.parameter.tokenizer,
|
scheduler_state=self.config.scheduler.state_dict(),
|
||||||
config=self.trainer.parameter.config,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# resume from the assigned checkpoint or assigned iteration
|
# resume from the assigned checkpoint or assigned iteration
|
||||||
self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch)
|
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||||
self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch)
|
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_optimizer(self) -> Self:
|
def with_optimizer(self) -> Self:
|
||||||
|
optimizer = self.config.optimizer
|
||||||
if self._context is None:
|
if self._context is None:
|
||||||
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
|
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
|
||||||
|
|
||||||
optimizer = self.trainer.train_config.optimizer
|
|
||||||
|
|
||||||
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
||||||
optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
|
optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
|
||||||
|
|
||||||
|
|
@ -67,13 +66,7 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_scheduler(self) -> Self:
|
def with_scheduler(self) -> Self:
|
||||||
if not hasattr(self._context, 'optimizer') or self._context.optimizer is None:
|
scheduler = self.config.scheduler
|
||||||
raise RuntimeError("Must call with_optimizer() before with_scheduler()")
|
|
||||||
|
|
||||||
optimizer = self.trainer.train_config.optimizer
|
|
||||||
schedule_config = self.trainer.schedule_config
|
|
||||||
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
|
|
||||||
|
|
||||||
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
|
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
|
||||||
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
|
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
|
||||||
|
|
||||||
|
|
@ -85,29 +78,41 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
def with_dataloader(self) -> Self:
|
||||||
# fix: change batch level batch_iter to sample level offset
|
# fix: change batch level iteration to sample level offset
|
||||||
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
|
config = self.config
|
||||||
|
sampler_offset = self._context.iteration * config.batch_size
|
||||||
resumeable_sampler = ResumableDistributedSampler(
|
resumeable_sampler = ResumableDistributedSampler(
|
||||||
data_source=self.trainer.train_config.dataset,
|
data_source=config.dataset,
|
||||||
start_epoch=self._context.epoch,
|
start_epoch=self._context.epoch,
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=self.trainer.train_config.random_seed
|
seed=config.random_seed
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
self.trainer.train_config.dataset,
|
config.dataset,
|
||||||
batch_size=self.trainer.train_config.batch_size,
|
batch_size=config.batch_size,
|
||||||
sampler=resumeable_sampler,
|
sampler=resumeable_sampler,
|
||||||
num_workers=self.trainer.train_config.num_workers,
|
num_workers=config.num_workers,
|
||||||
pin_memory=self.trainer.train_config.pin_memory,
|
pin_memory=config.pin_memory,
|
||||||
prefetch_factor=self.trainer.train_config.prefetch_factor
|
prefetch_factor=config.prefetch_factor
|
||||||
)
|
)
|
||||||
self._context.dataloader = dataloader
|
self._context.dataloader = dataloader
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def with_strategy(self) -> Self:
|
||||||
|
device = get_current_device()
|
||||||
|
self._context.strategy = StrategyFactory.load(
|
||||||
|
model=self.config.model,
|
||||||
|
train_type=self.config.strategy,
|
||||||
|
device=device,
|
||||||
|
**self.config.kwargs
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
if self.trainer.train_config.nprocs > 1:
|
def build(self) -> TrainContext:
|
||||||
|
self._context.model = self.config.model
|
||||||
|
|
||||||
|
if self.config.nprocs > 1:
|
||||||
self._context.wolrd_size = get_world_size()
|
self._context.wolrd_size = get_world_size()
|
||||||
self._context.rank = get_rank()
|
self._context.rank = get_rank()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from khaosz.config import (
|
from khaosz.config import TrainConfig
|
||||||
ModelParameter,
|
|
||||||
Checkpoint,
|
|
||||||
ScheduleConfig,
|
|
||||||
TrainConfig
|
|
||||||
)
|
|
||||||
from khaosz.trainer.train_callback import (
|
from khaosz.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
|
|
@ -13,7 +8,7 @@ from khaosz.trainer.train_callback import (
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
SchedulerCallback
|
SchedulerCallback
|
||||||
)
|
)
|
||||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -21,66 +16,65 @@ logger = logging.getLogger(__name__)
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parameter: ModelParameter,
|
|
||||||
train_config: TrainConfig,
|
train_config: TrainConfig,
|
||||||
schedule_config: ScheduleConfig,
|
|
||||||
callbacks: Optional[List[TrainCallback]] = None
|
callbacks: Optional[List[TrainCallback]] = None
|
||||||
):
|
):
|
||||||
self.parameter = parameter
|
|
||||||
self.train_config = train_config
|
self.train_config = train_config
|
||||||
self.schedule_config = schedule_config
|
|
||||||
self.callbacks = callbacks or self._get_default_callbacks()
|
self.callbacks = callbacks or self._get_default_callbacks()
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
|
train_config = self.train_config
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(),
|
ProgressBarCallback(train_config.n_epoch),
|
||||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
|
||||||
GradientClippingCallback(self.train_config.max_grad_norm),
|
GradientClippingCallback(train_config.max_grad_norm),
|
||||||
SchedulerCallback(self.schedule_config),
|
SchedulerCallback(train_config.scheduler),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (TrainContextBuilder(self)
|
return (TrainContextBuilder(self.train_config)
|
||||||
.with_checkpoint(checkpoint)
|
.with_checkpoint(checkpoint)
|
||||||
.with_optimizer()
|
.with_optimizer()
|
||||||
.with_scheduler()
|
.with_scheduler()
|
||||||
.with_dataloader()
|
.with_dataloader()
|
||||||
|
.with_strategy()
|
||||||
.build())
|
.build())
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
method = getattr(callback, method_name, None)
|
method = getattr(callback, method_name, None)
|
||||||
if method:
|
if method:
|
||||||
method(self, context)
|
method(context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||||
context = self._build_context(checkpoint)
|
context = self._build_context(checkpoint)
|
||||||
self._call_callbacks('on_train_begin', context)
|
self._call_callbacks('on_train_begin', context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.parameter.model.train()
|
context.model.train()
|
||||||
# 1.epoch
|
# 1.epoch
|
||||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks('on_epoch_begin', context)
|
self._call_callbacks('on_epoch_begin', context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
if context.batch_iter % self.train_config.accumulation_steps == 0:
|
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||||
# 2. step
|
# 2. step
|
||||||
self._call_callbacks('on_step_begin', context)
|
self._call_callbacks('on_step_begin', context)
|
||||||
self.train_config.optimizer.step()
|
context.optimizer.step()
|
||||||
self.train_config.optimizer.zero_grad()
|
context.optimizer.zero_grad()
|
||||||
self._call_callbacks('on_step_end', context)
|
self._call_callbacks('on_step_end', context)
|
||||||
|
|
||||||
# 3. batch
|
# 3. batch
|
||||||
self._call_callbacks('on_batch_begin', context)
|
self._call_callbacks('on_batch_begin', context)
|
||||||
loss = self.train_config.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
context.batch_iter += 1
|
context.iteration += 1
|
||||||
|
|
||||||
# to make the loss normalized by accumulation steps
|
# to make the loss normalized by accumulation steps
|
||||||
normalized_loss = loss / self.train_config.accumulation_steps
|
stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs
|
||||||
normalized_loss.backward()
|
stand_loss = loss / stand_batch
|
||||||
|
stand_loss.backward()
|
||||||
|
|
||||||
self._call_callbacks('on_batch_end', context)
|
self._call_callbacks('on_batch_end', context)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,20 @@ from khaosz.trainer import *
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
|
schedule_config = CosineScheduleConfig(
|
||||||
|
warmup_steps=10,
|
||||||
|
total_steps=20
|
||||||
|
)
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||||
|
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
|
model=base_test_env["model"],
|
||||||
|
strategy='seq',
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
checkpoint_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
|
|
@ -18,36 +28,26 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
random_seed=42
|
random_seed=42
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create custom callbacks to track calls
|
# Create custom callbacks to track calls
|
||||||
callback_calls = []
|
callback_calls = []
|
||||||
|
|
||||||
class TrackingCallback(TrainCallback):
|
class TrackingCallback(TrainCallback):
|
||||||
def on_train_begin(self, trainer, context):
|
def on_train_begin(self, context):
|
||||||
callback_calls.append('on_train_begin')
|
callback_calls.append('on_train_begin')
|
||||||
|
|
||||||
def on_batch_end(self, trainer, context):
|
def on_batch_end(self, context):
|
||||||
callback_calls.append('on_batch_end')
|
callback_calls.append('on_batch_end')
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, context):
|
def on_epoch_end(self, context):
|
||||||
callback_calls.append('on_epoch_end')
|
callback_calls.append('on_epoch_end')
|
||||||
|
|
||||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
|
||||||
model_parameter = ModelParameter(
|
|
||||||
base_test_env["model"],
|
|
||||||
base_test_env["tokenizer"],
|
|
||||||
base_test_env["transformer_config"]
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model_parameter,
|
|
||||||
train_config,
|
train_config,
|
||||||
schedule_config,
|
callbacks=[TrackingCallback()]
|
||||||
callbacks=[TrackingCallback(), ProgressBarCallback()]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
|
||||||
|
|
@ -7,35 +7,34 @@ from khaosz.trainer import *
|
||||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
"""Simulate early stopping behavior"""
|
"""Simulate early stopping behavior"""
|
||||||
|
|
||||||
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||||
|
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
|
strategy="seq",
|
||||||
|
scheduler=scheduler,
|
||||||
|
model=base_test_env["model"],
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
checkpoint_dir=base_test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=2,
|
checkpoint_interval=1,
|
||||||
accumulation_steps=2,
|
accumulation_steps=2,
|
||||||
random_seed=np.random.randint(1e4),
|
random_seed=np.random.randint(1e4),
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
trainer = Trainer(train_config)
|
||||||
model_parameter = ModelParameter(
|
|
||||||
base_test_env["model"],
|
|
||||||
base_test_env["tokenizer"],
|
|
||||||
base_test_env["transformer_config"]
|
|
||||||
)
|
|
||||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
|
||||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
|
||||||
|
|
||||||
# Should handle early stopping gracefully
|
# Should handle early stopping gracefully
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
try:
|
try:
|
||||||
checkpoint = trainer.train()
|
checkpoint = trainer.train()
|
||||||
assert len(checkpoint.loss_list) == 2
|
assert checkpoint.iteration == 2
|
||||||
except Exception:
|
except Exception:
|
||||||
# Handle any exceptions
|
# Handle any exceptions
|
||||||
pass
|
pass
|
||||||
|
|
||||||
checkpoint = trainer.train(checkpoint)
|
checkpoint = trainer.train(checkpoint)
|
||||||
assert len(checkpoint.loss_list) == 10
|
assert checkpoint.iteration == 10
|
||||||
|
|
@ -51,13 +51,6 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
|
|
||||||
shutil.rmtree(test_dir)
|
shutil.rmtree(test_dir)
|
||||||
|
|
||||||
# parameter loader
|
|
||||||
def test_parameter_loader(test_env):
|
|
||||||
loaded_param = ParameterLoader.load(test_env["test_dir"])
|
|
||||||
assert loaded_param.model is not None
|
|
||||||
assert loaded_param.tokenizer is not None
|
|
||||||
assert loaded_param.config == test_env["transformer_config"]
|
|
||||||
|
|
||||||
def test_model_parameter(test_env):
|
def test_model_parameter(test_env):
|
||||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||||
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
|
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
|
||||||
|
|
|
||||||
|
|
@ -31,10 +31,18 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
accumulation_steps_list = [1, 2, 4]
|
accumulation_steps_list = [1, 2, 4]
|
||||||
|
|
||||||
for accumulation_steps in accumulation_steps_list:
|
for accumulation_steps in accumulation_steps_list:
|
||||||
|
schedule_config = CosineScheduleConfig(
|
||||||
|
warmup_steps=10,
|
||||||
|
total_steps=20
|
||||||
|
)
|
||||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||||
|
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
dataset=random_dataset,
|
strategy="seq",
|
||||||
|
model=base_test_env["model"],
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
dataset=random_dataset,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
checkpoint_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
|
|
@ -44,18 +52,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
random_seed=42
|
random_seed=42
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
trainer = Trainer(train_config)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
|
||||||
model_parameter = ModelParameter(
|
|
||||||
base_test_env["model"],
|
|
||||||
base_test_env["tokenizer"],
|
|
||||||
base_test_env["transformer_config"]
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
assert train_config.accumulation_steps == accumulation_steps
|
assert train_config.accumulation_steps == accumulation_steps
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ def test_schedule_factory_random_configs():
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
# Create scheduler using factory
|
# Create scheduler using factory
|
||||||
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
|
|
||||||
# Verify scheduler type
|
# Verify scheduler type
|
||||||
if isinstance(config, CosineScheduleConfig):
|
if isinstance(config, CosineScheduleConfig):
|
||||||
|
|
@ -83,7 +83,7 @@ def test_schedule_factory_edge_cases():
|
||||||
|
|
||||||
for config in edge_cases:
|
for config in edge_cases:
|
||||||
config.validate()
|
config.validate()
|
||||||
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
assert scheduler is not None
|
assert scheduler is not None
|
||||||
|
|
||||||
# Test multiple steps
|
# Test multiple steps
|
||||||
|
|
@ -97,16 +97,17 @@ def test_schedule_factory_invalid_configs():
|
||||||
# Test invalid configurations that should raise errors
|
# Test invalid configurations that should raise errors
|
||||||
invalid_configs = [
|
invalid_configs = [
|
||||||
# Negative warmup steps
|
# Negative warmup steps
|
||||||
CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1),
|
{"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
|
||||||
# Total steps less than warmup steps
|
# Total steps less than warmup steps
|
||||||
CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1),
|
{"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
|
||||||
# Invalid min_rate
|
# Invalid min_rate
|
||||||
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1),
|
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
|
||||||
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1),
|
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in invalid_configs:
|
for kwargs in invalid_configs:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
config = CosineScheduleConfig(**kwargs)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -117,7 +118,7 @@ def test_schedule_factory_state_persistence():
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
|
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
|
||||||
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
|
|
||||||
# Take a few steps
|
# Take a few steps
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
|
|
@ -127,7 +128,7 @@ def test_schedule_factory_state_persistence():
|
||||||
state_dict = scheduler.state_dict()
|
state_dict = scheduler.state_dict()
|
||||||
|
|
||||||
# Create new scheduler and load state
|
# Create new scheduler and load state
|
||||||
new_scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
new_scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
new_scheduler.load_state_dict(state_dict)
|
new_scheduler.load_state_dict(state_dict)
|
||||||
|
|
||||||
# Verify states match
|
# Verify states match
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@ import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import Trainer, StrategyFactory
|
from khaosz.trainer import Trainer, SchedulerFactory
|
||||||
from khaosz.data import DatasetLoader
|
from khaosz.data import DatasetLoader
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -36,7 +36,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
||||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||||
parser.add_argument("--resume_from_checkpoint", action="store_true", help="Train from checkpoint or not.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -66,16 +65,12 @@ def train(
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
resume_from_checkpoint: bool
|
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
parameter = ParameterLoader.load(param_path)
|
parameter = ModelParameter()
|
||||||
checkpoint = None
|
parameter.load(param_path)
|
||||||
|
|
||||||
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
|
||||||
checkpoint = parameter
|
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = parameter.config.m_len
|
window_size = parameter.config.m_len
|
||||||
|
|
@ -91,13 +86,6 @@ def train(
|
||||||
"pad_token_id": parameter.tokenizer.pad_id,
|
"pad_token_id": parameter.tokenizer.pad_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy = StrategyFactory.load(
|
|
||||||
model,
|
|
||||||
train_type,
|
|
||||||
device,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
|
|
@ -111,16 +99,25 @@ def train(
|
||||||
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
|
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
|
||||||
]
|
]
|
||||||
|
|
||||||
optim = AdamW(
|
optimizer = AdamW(
|
||||||
param_groups,
|
param_groups,
|
||||||
betas=(adamw_beta1, adamw_beta2),
|
betas=(adamw_beta1, adamw_beta2),
|
||||||
weight_decay=adamw_weight_decay
|
weight_decay=adamw_weight_decay
|
||||||
)
|
)
|
||||||
|
|
||||||
|
schedule_config = CosineScheduleConfig(
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
total_steps=len(dataset) * n_epoch // batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
strategy=strategy,
|
model=model,
|
||||||
|
strategy=train_type,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer=optim,
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|
@ -134,17 +131,8 @@ def train(
|
||||||
pin_memory=pin_memory
|
pin_memory=pin_memory
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
trainer = Trainer(train_config)
|
||||||
warmup_steps=warmup_steps,
|
trainer.train()
|
||||||
total_steps=len(dataset) * n_epoch // batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(
|
|
||||||
parameter=parameter,
|
|
||||||
train_config=train_config,
|
|
||||||
schedule_config=schedule_config,
|
|
||||||
)
|
|
||||||
trainer.train(checkpoint)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue