refactor(trainer): 优化trainer 结构
This commit is contained in:
parent
82e65ccc21
commit
c98b175cd5
|
|
@ -4,7 +4,6 @@ __author__ = "ViperEkura"
|
|||
from khaosz.api import Khaosz
|
||||
from khaosz.config import (
|
||||
ModelConfig,
|
||||
ParameterLoader,
|
||||
TrainConfig,
|
||||
)
|
||||
from khaosz.model.transformer import Transformer
|
||||
|
|
@ -42,7 +41,6 @@ __all__ = [
|
|||
"PriorityTextSplitter",
|
||||
|
||||
"ModelConfig",
|
||||
"ParameterLoader",
|
||||
"TrainConfig",
|
||||
|
||||
"DatasetLoader",
|
||||
|
|
|
|||
|
|
@ -9,12 +9,13 @@ from khaosz.inference.generator import (
|
|||
RetrievalGenerator,
|
||||
EmbeddingEncoder
|
||||
)
|
||||
from khaosz.config.param_config import ParameterLoader
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
|
||||
|
||||
class Khaosz:
|
||||
def __init__(self, model_dir: str):
|
||||
self.parameter = ParameterLoader.load(model_dir)
|
||||
self.parameter = ModelParameter()
|
||||
self.parameter.load(model_dir)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
self.parameter.to(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
||||
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
||||
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
|
||||
from khaosz.config.train_config import TrainConfig
|
||||
|
||||
|
|
@ -7,8 +7,6 @@ from khaosz.config.train_config import TrainConfig
|
|||
__all__ = [
|
||||
"BaseModelIO",
|
||||
"ModelParameter",
|
||||
"Checkpoint",
|
||||
"ParameterLoader",
|
||||
"ModelConfig",
|
||||
"TrainConfig",
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,8 @@
|
|||
import pickle as pkl
|
||||
import matplotlib.pyplot as plt
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import safetensors.torch as st
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Self, Union
|
||||
from typing import Optional, Self, Union
|
||||
from pathlib import Path
|
||||
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
|
|
@ -63,7 +60,7 @@ class BaseModelIO:
|
|||
|
||||
return self
|
||||
|
||||
def to(self, *args, **kwargs) -> Self:
|
||||
def to(self, *args, **kwargs) -> "BaseModelIO":
|
||||
"""Move model to device."""
|
||||
if self.model is not None:
|
||||
self.model.to(*args, **kwargs)
|
||||
|
|
@ -77,142 +74,6 @@ class ModelParameter(BaseModelIO):
|
|||
def save(self, save_dir: Union[str, Path]):
|
||||
self.save_components(save_dir)
|
||||
|
||||
def load(self, load_dir: Union[str, Path]) -> Self:
|
||||
def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||
return self.load_components(load_dir)
|
||||
|
||||
|
||||
@dataclass
class Checkpoint(BaseModelIO):
    """Extended model parameters with training state.

    On top of the components handled by ``BaseModelIO`` this class keeps the
    optimizer/scheduler state, the loss history and the training position,
    and persists them as ``training_state.pkl`` (plus a ``loss_plot.png``).
    """

    # State fields are optional: a freshly created checkpoint has no state.
    optimizer_state: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Optimizer state."}
    )
    scheduler_state: Optional[Dict[str, Any]] = field(
        default=None,
        metadata={"help": "Scheduler state."}
    )
    loss_list: List[float] = field(
        default_factory=list,
        metadata={"help": "List of training losses."}
    )
    epoch: int = field(
        default=0,
        metadata={"help": "Current epoch."}
    )
    batch_iter: int = field(
        default=0,
        metadata={"help": "Current iteration."}
    )

    def _get_training_paths(self, directory: Union[str, Path]) -> Dict[str, Path]:
        """Return the file paths used for the training state inside *directory*."""
        dir_path = Path(directory)
        return {
            "loss_plot": dir_path / "loss_plot.png",
            "training_state": dir_path / "training_state.pkl"
        }

    def to_dict(self) -> Dict[str, Any]:
        """Return the training state as a plain dictionary."""
        return {
            "optimizer_state": self.optimizer_state,
            "scheduler_state": self.scheduler_state,
            "epoch": self.epoch,
            "batch_iter": self.batch_iter,
            "loss_list": self.loss_list,
        }

    def from_dict(self, data: Dict[str, Any]) -> Self:
        """Restore the training state from *data* and return ``self``.

        Raises:
            KeyError: If *data* lacks one of the keys written by ``to_dict``.
        """
        self.optimizer_state = data["optimizer_state"]
        self.scheduler_state = data["scheduler_state"]
        self.epoch = data["epoch"]
        self.batch_iter = data["batch_iter"]
        self.loss_list = data["loss_list"]
        # The signature promises ``Self``; previously this returned None.
        return self

    def save_training_state(self, save_dir: Union[str, Path]):
        """Write the loss plot and the pickled training state to *save_dir*."""
        paths = self._get_training_paths(save_dir)

        # Make the method usable on its own, not only after save_components().
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        # Save loss plot
        self._plot_loss(str(paths["loss_plot"]))

        # Save training state
        with open(str(paths["training_state"]), "wb") as f:
            pkl.dump(self.to_dict(), f)

    def load_training_state(self, load_dir: Union[str, Path]) -> Self:
        """Read ``training_state.pkl`` from *load_dir* and return ``self``."""
        paths = self._get_training_paths(load_dir)

        # NOTE: pickle can execute arbitrary code while loading; only load
        # checkpoint directories that come from a trusted source.
        with open(str(paths["training_state"]), "rb") as f:
            train_state = pkl.load(f)

        return self.from_dict(train_state)

    def _plot_loss(self, save_path: str):
        """Plot and save loss curve; does nothing when no loss was recorded."""
        if not self.loss_list:
            return

        num_points = len(self.loss_list)

        plt.figure(figsize=(10, 6))
        try:
            plt.plot(self.loss_list)
            plt.title(f"Training Loss - Iteration {num_points}")
            plt.xlabel("Batch")
            plt.ylabel("Loss")
            plt.grid(True)
            plt.savefig(save_path, dpi=30, bbox_inches="tight")
        finally:
            # Always release the figure, even if saving fails.
            plt.close()

    def save(self, save_dir: Union[str, Path]):
        """Save complete checkpoint."""
        self.save_components(save_dir)
        self.save_training_state(save_dir)

    def load(self, load_dir: Union[str, Path]) -> Self:
        """Load complete checkpoint."""
        self.load_components(load_dir)
        self.load_training_state(load_dir)
        return self
|
||||
|
||||
|
||||
class ParameterLoader:
    """Factory class for loading model parameters or checkpoints."""

    @staticmethod
    def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
        """Load either ModelParameter or Checkpoint based on directory contents.

        A directory is treated as a checkpoint when it contains
        ``training_state.pkl``, the file written by
        ``Checkpoint.save_training_state``. The previous probe looked for
        ``loss.pkl``, which is never written, so checkpoints were always
        loaded as plain ``ModelParameter`` objects.
        """
        load_dir = Path(load_dir)

        # Check for training-specific files
        has_training_data = (load_dir / "training_state.pkl").exists()

        # Create appropriate instance
        if has_training_data:
            checkpoint = Checkpoint()
            checkpoint.load(str(load_dir))
            return checkpoint

        params = ModelParameter()
        params.load(str(load_dir))
        return params

    @staticmethod
    def create_checkpoint(
        model: nn.Module,
        tokenizer: BpeTokenizer,
        config: ModelConfig,
        loss_list: Optional[list[float]] = None,
        optimizer: Optional[optim.Optimizer] = None,
    ) -> Checkpoint:
        """Convenience method to create a training checkpoint.

        The optimizer is stored as its ``state_dict()``: the
        ``optimizer_state`` field is later fed to ``load_state_dict``, so it
        must hold the state mapping rather than the optimizer object itself.
        """
        optimizer_state = optimizer.state_dict() if optimizer is not None else None
        return Checkpoint(
            model=model,
            tokenizer=tokenizer,
            config=config,
            loss_list=loss_list or [],
            optimizer_state=optimizer_state,
        )
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Literal, Dict
|
||||
from typing import Any, Dict
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
|
@ -39,7 +39,10 @@ class CosineScheduleConfig(ScheduleConfig):
|
|||
default=None,
|
||||
metadata={"help": "Total training steps for cosine schedule."}
|
||||
)
|
||||
schedule_type: Literal["cosine"] = "cosine"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.schedule_type = "cosine"
|
||||
self.validate()
|
||||
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
if self.total_steps is None:
|
||||
|
|
@ -68,7 +71,10 @@ class SGDRScheduleConfig(ScheduleConfig):
|
|||
default=2,
|
||||
metadata={"help": "Multiplier for cycle length growth."}
|
||||
)
|
||||
schedule_type: Literal["sgdr"] = "sgdr"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.schedule_type = "sgdr"
|
||||
self.validate()
|
||||
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from khaosz.trainer.strategy import BaseStrategy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
strategy: "BaseStrategy" = field(
|
||||
# basic setting
|
||||
model: nn.Module = field(
|
||||
default=None,
|
||||
metadata={"help": "Model for training."}
|
||||
)
|
||||
strategy: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Training strategy."}
|
||||
)
|
||||
|
|
@ -21,9 +26,9 @@ class TrainConfig:
|
|||
default=None,
|
||||
metadata={"help": "Optimizer for training."}
|
||||
)
|
||||
checkpoint_dir: str = field(
|
||||
default="./checkpoint",
|
||||
metadata={"help": "Checkpoint directory."}
|
||||
scheduler: LRScheduler = field(
|
||||
default=None,
|
||||
metadata={"help": "Scheduler for training."}
|
||||
)
|
||||
n_epoch: int = field(
|
||||
default=1,
|
||||
|
|
@ -33,6 +38,16 @@ class TrainConfig:
|
|||
default=4,
|
||||
metadata={"help": "Batch size for training."}
|
||||
)
|
||||
accumulation_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
|
||||
# checkpoint setting
|
||||
start_epoch: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Start epoch for training."}
|
||||
|
|
@ -41,18 +56,14 @@ class TrainConfig:
|
|||
default=0,
|
||||
metadata={"help": "Start batch iteration for training."}
|
||||
)
|
||||
checkpoint_dir: str = field(
|
||||
default="./checkpoint",
|
||||
metadata={"help": "Checkpoint directory."}
|
||||
)
|
||||
checkpoint_interval: int = field(
|
||||
default=5000,
|
||||
metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
accumulation_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
|
||||
# dataloader setting
|
||||
random_seed: int = field(
|
||||
|
|
@ -77,3 +88,9 @@ class TrainConfig:
|
|||
default=1,
|
||||
metadata={"help": "Number of processes for distributed training."}
|
||||
)
|
||||
|
||||
# others
|
||||
kwargs: dict = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
import os
|
||||
import pickle as pkl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class Checkpoint:
    """Snapshot of the training state that can be saved to and loaded from disk.

    The state is pickled to ``train_state.pkl`` inside a checkpoint
    directory; optionally one ``<metric>.png`` curve per metric is written
    into the same directory.

    Args:
        optimizer_state: Optimizer state to persist (stored as given; the
            checkpoint callback passes ``optimizer.state_dict()``).
        scheduler_state: Scheduler state to persist (stored as given).
        epoch: Epoch counter.
        iteration: Iteration counter.
        metrics: Mapping of metric name to the list of recorded values.

    Both state arguments default to ``None`` so that an empty instance can be
    created and then filled with :meth:`load`.
    """

    _STATE_FILE = "train_state.pkl"

    def __init__(
        self,
        optimizer_state: Optional[dict] = None,
        scheduler_state: Optional[dict] = None,
        epoch: int = 0,
        iteration: int = 0,
        metrics: Optional[Dict[str, list]] = None,
    ):
        self.optimizer_state = optimizer_state
        self.scheduler_state = scheduler_state
        self.epoch, self.iteration = epoch, iteration
        self.metrics = metrics

    def save(self, save_dir: str, save_metric_plot=True) -> None:
        """Write the training state (and optional metric plots) to *save_dir*.

        Args:
            save_dir: Target directory; created if it does not exist.
            save_metric_plot: When true and metrics are present, also save
                one PNG per metric into *save_dir*.
        """
        os.makedirs(save_dir, exist_ok=True)

        train_state = {
            "epoch": self.epoch,
            "iteration": self.iteration,
            "metrics": self.metrics,
            "optimizer_state": self.optimizer_state,
            "scheduler_state": self.scheduler_state,
        }

        with open(os.path.join(save_dir, self._STATE_FILE), "wb") as f:
            pkl.dump(train_state, f)

        if save_metric_plot and self.metrics:
            # Plots belong next to the state file, not in the working directory.
            self._plot_metrics(save_dir)

    def load(self, save_dir: str) -> "Checkpoint":
        """Restore the training state from *save_dir* and return ``self``.

        Raises:
            FileNotFoundError: If *save_dir* (or the state file in it) is missing.
        """
        if not os.path.exists(save_dir):
            raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")

        # NOTE: pickle can execute arbitrary code while loading; only load
        # checkpoint directories that come from a trusted source.
        with open(os.path.join(save_dir, self._STATE_FILE), "rb") as f:
            train_state = pkl.load(f)

        self.epoch = train_state["epoch"]
        self.iteration = train_state["iteration"]
        self.metrics = train_state["metrics"]
        self.optimizer_state = train_state["optimizer_state"]
        self.scheduler_state = train_state["scheduler_state"]

        return self

    def _plot_metrics(self, save_dir: str = "."):
        """Save one ``<metric_name>.png`` curve per metric into *save_dir*."""
        for metric_name, metric_value in (self.metrics or {}).items():
            plt.figure(figsize=(10, 6))
            try:
                plt.plot(metric_value, label=metric_name)
                plt.xlabel('Step')
                plt.ylabel('Value')
                plt.legend()
                plt.grid(True, alpha=0.3)

                plt.savefig(
                    os.path.join(save_dir, f'{metric_name}.png'),
                    dpi=150,
                    bbox_inches='tight',
                )
            finally:
                # Always release the figure, even if saving fails.
                plt.close()
|
||||
|
|
@ -151,7 +151,7 @@ class SchedulerFactory:
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load_scheduler(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
|
||||
def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
|
||||
kwargs = scedule_config.get_kwargs()
|
||||
schedule_type = kwargs.pop("schedule_type")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Any, Tuple, Callable, Dict, Union
|
||||
from typing import Any, Callable, Dict, Union
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ class BaseStrategy(ABC):
|
|||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
return self.compute_loss(batch)
|
||||
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ class DpoStrategy(BaseStrategy):
|
|||
self.pad_token_id = pad_token_id
|
||||
self.beta = beta
|
||||
|
||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
||||
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||
|
|
@ -115,41 +115,6 @@ class DpoStrategy(BaseStrategy):
|
|||
return dpo_loss
|
||||
|
||||
|
||||
class PpoStrategy(BaseStrategy):
    """PPO strategy that keeps a frozen copy of the model as reference."""

    def __init__(self, model, pad_token_id, epsilon):
        super().__init__(model)

        # Frozen, eval-mode deep copy of the wrapped model.
        frozen_copy = copy.deepcopy(self.model)
        frozen_copy.requires_grad_(False)
        frozen_copy.eval()
        self.ref_model = frozen_copy

        self.pad_token_id = pad_token_id
        self.epsilon = epsilon

    def ppo_clip_loss_masked(
        self,
        log_probs: Tensor,
        old_log_probs: Tensor,
        advantages: Tensor,
        values: Tensor,
        returns: Tensor,
        mask: Tensor,
        clip_eps: float=0.2,
    ):
        """Return ``(policy_loss, value_loss, entropy_loss)`` over masked positions.

        Only elements selected by ``mask`` contribute to each term.
        """
        # Clipped surrogate objective.
        prob_ratio = torch.exp(log_probs - old_log_probs)
        unclipped = prob_ratio * advantages
        clipped = torch.clamp(prob_ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        surrogate = torch.min(unclipped, clipped)
        policy_loss = -surrogate.masked_select(mask).mean()

        # Mean squared error between predicted values and returns.
        masked_values = values.masked_select(mask)
        masked_returns = returns.masked_select(mask)
        value_loss = F.mse_loss(masked_values, masked_returns)

        # Entropy term, negated so that minimising it raises the entropy.
        entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
        entropy_loss = -entropy

        return policy_loss, value_loss, entropy_loss
|
||||
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
|
||||
def load(model, train_type, device, **kwargs):
|
||||
|
|
|
|||
|
|
@ -5,10 +5,9 @@ import time
|
|||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from typing import List, Optional, Protocol, TYPE_CHECKING
|
||||
|
||||
from khaosz.config import ScheduleConfig
|
||||
from khaosz.trainer.metric_util import (
|
||||
grad_max,
|
||||
grad_min,
|
||||
|
|
@ -17,9 +16,9 @@ from khaosz.trainer.metric_util import (
|
|||
grad_std,
|
||||
grad_nan_num
|
||||
)
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.trainer.train_context import TrainContext
|
||||
|
||||
|
||||
|
|
@ -28,31 +27,31 @@ class TrainCallback(Protocol):
|
|||
Callback interface for trainer.
|
||||
"""
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_train_begin(self, context: 'TrainContext'):
|
||||
""" Called at the beginning of training. """
|
||||
|
||||
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_train_end(self, context: 'TrainContext'):
|
||||
""" Called at the end of training. """
|
||||
|
||||
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_epoch_begin(self, context: 'TrainContext'):
|
||||
""" Called at the beginning of each epoch. """
|
||||
|
||||
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_epoch_end(self, context: 'TrainContext'):
|
||||
""" Called at the end of each epoch. """
|
||||
|
||||
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_step_begin(self, context: 'TrainContext'):
|
||||
""" Called at the beginning of each step. """
|
||||
|
||||
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_step_end(self, context: 'TrainContext'):
|
||||
""" Called at the end of each step."""
|
||||
|
||||
def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_batch_begin(self, context: 'TrainContext'):
|
||||
""" Called at the beginning of each batch. """
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_batch_end(self, context: 'TrainContext'):
|
||||
""" Called at the end of each batch. """
|
||||
|
||||
def on_error(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_error(self, context: 'TrainContext'):
|
||||
""" Called when an error occurs during training. """
|
||||
|
||||
|
||||
|
|
@ -63,29 +62,27 @@ class GradientClippingCallback(TrainCallback):
|
|||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_step_begin(self, context: 'TrainContext'):
|
||||
_ = context
|
||||
clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm)
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
class SchedulerCallback(TrainCallback):
|
||||
"""
|
||||
Scheduler callback for trainer.
|
||||
"""
|
||||
def __init__(self, schedule_config: ScheduleConfig):
|
||||
self.schedule_config = schedule_config
|
||||
self.scheduler: Optional[LambdaLR] = None
|
||||
def __init__(self, scheduler: LRScheduler):
|
||||
self.scheduler: LRScheduler = scheduler
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
|
||||
for group in trainer.train_config.optimizer.param_groups:
|
||||
def on_train_begin(self, context: 'TrainContext'):
|
||||
for group in context.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
self.scheduler = context.scheduler
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
_ = trainer, context
|
||||
def on_batch_end(self, context: 'TrainContext'):
|
||||
_ = context
|
||||
if self.scheduler:
|
||||
self.scheduler.step()
|
||||
|
||||
|
|
@ -94,54 +91,59 @@ class CheckpointCallback(TrainCallback):
|
|||
"""
|
||||
Checkpoint callback for trainer.
|
||||
"""
|
||||
def __init__(self, checkpoint_interval: int):
|
||||
self.checkpoint_interval = checkpoint_interval
|
||||
def __init__(self, interval: int, save_dir: str):
|
||||
self.interval = interval
|
||||
self.save_dir = save_dir
|
||||
self.checkpoint = None
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}")
|
||||
context.checkpoint.optimizer_state = context.optimizer.state_dict()
|
||||
context.checkpoint.scheduler_state = context.scheduler.state_dict()
|
||||
context.checkpoint.epoch = context.epoch
|
||||
context.checkpoint.batch_iter = context.batch_iter
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.batch_iter
|
||||
def _save_checkpoint(self, context: 'TrainContext'):
|
||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
|
||||
self.checkpoint = Checkpoint(
|
||||
context.optimizer.state_dict(),
|
||||
context.scheduler.state_dict(),
|
||||
context.epoch,
|
||||
context.iteration
|
||||
)
|
||||
self.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
context.checkpoint.loss_list.append(context.loss)
|
||||
def on_batch_end(self, context: 'TrainContext'):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
self._save_checkpoint(context)
|
||||
|
||||
if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
||||
self._save_checkpoint(trainer, context)
|
||||
def on_train_end(self, context: 'TrainContext'):
|
||||
if context.iteration != self.last_ckpt_iter:
|
||||
self._save_checkpoint(context)
|
||||
|
||||
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
if context.batch_iter != self.last_ckpt_iter:
|
||||
self._save_checkpoint(trainer, context)
|
||||
def on_error(self, context: 'TrainContext'):
|
||||
self._save_checkpoint(context)
|
||||
|
||||
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
"""
|
||||
Progress bar callback for trainer.
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, num_epoch: int):
|
||||
self.num_epoch = num_epoch
|
||||
self.progress_bar: tqdm = None
|
||||
|
||||
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_epoch_begin(self, context: 'TrainContext'):
|
||||
self.progress_bar = tqdm(
|
||||
context.dataloader,
|
||||
desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}",
|
||||
desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
|
||||
dynamic_ncols=True
|
||||
)
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
_ = trainer
|
||||
def on_batch_end(self, context: 'TrainContext'):
|
||||
self.progress_bar.set_postfix({
|
||||
"loss": f"{context.loss:.4f}",
|
||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
|
||||
})
|
||||
self.progress_bar.update(1)
|
||||
|
||||
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
_ = trainer, context
|
||||
def on_epoch_end(self, context: 'TrainContext'):
|
||||
_ = context
|
||||
if self.progress_bar:
|
||||
self.progress_bar.close()
|
||||
|
||||
|
|
@ -177,13 +179,13 @@ class StepMonitorCallback(TrainCallback):
|
|||
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _handle_info(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def _handle_info(self, context: 'TrainContext'):
|
||||
""" Logs training information to console and file. """
|
||||
|
||||
log_data = {
|
||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"epoch": context.epoch,
|
||||
"iter": context.batch_iter,
|
||||
"iter": context.iteration,
|
||||
"metrics": self.metrics,
|
||||
}
|
||||
|
||||
|
|
@ -193,34 +195,34 @@ class StepMonitorCallback(TrainCallback):
|
|||
elif metric == 'lr':
|
||||
log_data[metric] = context.optimizer.param_groups[-1]['lr']
|
||||
elif metric == 'grad_norm':
|
||||
log_data[metric] = grad_norm(trainer.parameter.model)
|
||||
log_data[metric] = grad_norm(context.model)
|
||||
elif metric == 'grad_std':
|
||||
log_data[metric] = grad_std(trainer.parameter.model)
|
||||
log_data[metric] = grad_std(context.model)
|
||||
elif metric == 'grad_max':
|
||||
log_data[metric] = grad_max(trainer.parameter.model)
|
||||
log_data[metric] = grad_max(context.model)
|
||||
elif metric == 'grad_min':
|
||||
log_data[metric] = grad_min(trainer.parameter.model)
|
||||
log_data[metric] = grad_min(context.model)
|
||||
elif metric == 'grad_mean':
|
||||
log_data[metric] = grad_mean(trainer.parameter.model)
|
||||
log_data[metric] = grad_mean(context.model)
|
||||
elif metric == 'grad_nan_num':
|
||||
log_data[metric] = grad_nan_num(trainer.parameter.model)
|
||||
log_data[metric] = grad_nan_num(context.model)
|
||||
else:
|
||||
raise ValueError(f"Invalid metric: {metric}")
|
||||
|
||||
return log_data
|
||||
|
||||
def _handle_log(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def _handle_log(self, context: 'TrainContext'):
|
||||
""" Logs training information to console and file. """
|
||||
log_data = self._handle_info(trainer, context)
|
||||
log_data = self._handle_info(context)
|
||||
try:
|
||||
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json"
|
||||
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.iteration}.json"
|
||||
with open(log_file, 'a') as f:
|
||||
json.dump(log_data, f, indent=4)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
def on_step_end(self, context: 'TrainContext'):
|
||||
if self.step_num % self.log_interval == 0:
|
||||
self._handle_log(trainer, context)
|
||||
self._handle_log(context)
|
||||
|
||||
self.step_num += 1
|
||||
|
|
@ -1,61 +1,60 @@
|
|||
from dataclasses import dataclass, field, fields
|
||||
from typing import Optional, Self, TYPE_CHECKING
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from khaosz.config import Checkpoint
|
||||
from khaosz.data import ResumableDistributedSampler
|
||||
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||
from khaosz.parallel.utils import get_world_size, get_rank
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.data import ResumableDistributedSampler
|
||||
from khaosz.trainer.checkpoint import Checkpoint
|
||||
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
||||
from khaosz.config.train_config import TrainConfig
|
||||
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Self
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainContext:
|
||||
model: nn.Module = field(default=None)
|
||||
strategy: BaseStrategy = field(default=None)
|
||||
dataloader: DataLoader = field(default=None)
|
||||
optimizer: Optimizer = field(default=None)
|
||||
scheduler: BaseScheduler = field(default=None)
|
||||
scheduler: LRScheduler = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
|
||||
epoch: int = field(default=0)
|
||||
batch_iter: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
|
||||
wolrd_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
|
||||
def asdict(self) -> dict:
|
||||
return {field.name: getattr(self, field.name)
|
||||
for field in fields(self)}
|
||||
|
||||
|
||||
class TrainContextBuilder:
|
||||
def __init__(self, trainer: 'Trainer'):
|
||||
self.trainer = trainer
|
||||
def __init__(self, config: TrainConfig):
|
||||
self.config = config
|
||||
self._context: TrainContext = None
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._context = TrainContext()
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
model=self.trainer.parameter.model,
|
||||
tokenizer=self.trainer.parameter.tokenizer,
|
||||
config=self.trainer.parameter.config,
|
||||
optimizer_state=self.config.optimizer.state_dict(),
|
||||
scheduler_state=self.config.scheduler.state_dict(),
|
||||
)
|
||||
else:
|
||||
# resume from the assigned checkpoint or assigned iteration
|
||||
self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch)
|
||||
self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch)
|
||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||
|
||||
self._context.checkpoint = checkpoint
|
||||
return self
|
||||
|
||||
def with_optimizer(self) -> Self:
|
||||
optimizer = self.config.optimizer
|
||||
if self._context is None:
|
||||
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
|
||||
|
||||
optimizer = self.trainer.train_config.optimizer
|
||||
|
||||
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
||||
optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
|
||||
|
||||
|
|
@ -67,13 +66,7 @@ class TrainContextBuilder:
|
|||
return self
|
||||
|
||||
def with_scheduler(self) -> Self:
|
||||
if not hasattr(self._context, 'optimizer') or self._context.optimizer is None:
|
||||
raise RuntimeError("Must call with_optimizer() before with_scheduler()")
|
||||
|
||||
optimizer = self.trainer.train_config.optimizer
|
||||
schedule_config = self.trainer.schedule_config
|
||||
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
|
||||
|
||||
scheduler = self.config.scheduler
|
||||
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
|
||||
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
|
||||
|
||||
|
|
@ -85,29 +78,41 @@ class TrainContextBuilder:
|
|||
return self
|
||||
|
||||
def with_dataloader(self) -> Self:
|
||||
# fix: change batch level batch_iter to sample level offset
|
||||
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
|
||||
# fix: change batch level iteration to sample level offset
|
||||
config = self.config
|
||||
sampler_offset = self._context.iteration * config.batch_size
|
||||
resumeable_sampler = ResumableDistributedSampler(
|
||||
data_source=self.trainer.train_config.dataset,
|
||||
data_source=config.dataset,
|
||||
start_epoch=self._context.epoch,
|
||||
start_iter=sampler_offset,
|
||||
seed=self.trainer.train_config.random_seed
|
||||
seed=config.random_seed
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
self.trainer.train_config.dataset,
|
||||
batch_size=self.trainer.train_config.batch_size,
|
||||
config.dataset,
|
||||
batch_size=config.batch_size,
|
||||
sampler=resumeable_sampler,
|
||||
num_workers=self.trainer.train_config.num_workers,
|
||||
pin_memory=self.trainer.train_config.pin_memory,
|
||||
prefetch_factor=self.trainer.train_config.prefetch_factor
|
||||
num_workers=config.num_workers,
|
||||
pin_memory=config.pin_memory,
|
||||
prefetch_factor=config.prefetch_factor
|
||||
)
|
||||
self._context.dataloader = dataloader
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
def with_strategy(self) -> Self:
|
||||
device = get_current_device()
|
||||
self._context.strategy = StrategyFactory.load(
|
||||
model=self.config.model,
|
||||
train_type=self.config.strategy,
|
||||
device=device,
|
||||
**self.config.kwargs
|
||||
)
|
||||
return self
|
||||
|
||||
if self.trainer.train_config.nprocs > 1:
|
||||
def build(self) -> TrainContext:
|
||||
self._context.model = self.config.model
|
||||
|
||||
if self.config.nprocs > 1:
|
||||
self._context.wolrd_size = get_world_size()
|
||||
self._context.rank = get_rank()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,6 @@
|
|||
import logging
|
||||
from typing import Optional, List
|
||||
from khaosz.config import (
|
||||
ModelParameter,
|
||||
Checkpoint,
|
||||
ScheduleConfig,
|
||||
TrainConfig
|
||||
)
|
||||
from khaosz.config import TrainConfig
|
||||
from khaosz.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
ProgressBarCallback,
|
||||
|
|
@ -13,7 +8,7 @@ from khaosz.trainer.train_callback import (
|
|||
GradientClippingCallback,
|
||||
SchedulerCallback
|
||||
)
|
||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -21,66 +16,65 @@ logger = logging.getLogger(__name__)
|
|||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
parameter: ModelParameter,
|
||||
train_config: TrainConfig,
|
||||
schedule_config: ScheduleConfig,
|
||||
callbacks: Optional[List[TrainCallback]] = None
|
||||
):
|
||||
self.parameter = parameter
|
||||
self.train_config = train_config
|
||||
self.schedule_config = schedule_config
|
||||
self.callbacks = callbacks or self._get_default_callbacks()
|
||||
|
||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||
train_config = self.train_config
|
||||
return [
|
||||
ProgressBarCallback(),
|
||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||
GradientClippingCallback(self.train_config.max_grad_norm),
|
||||
SchedulerCallback(self.schedule_config),
|
||||
ProgressBarCallback(train_config.n_epoch),
|
||||
CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
|
||||
GradientClippingCallback(train_config.max_grad_norm),
|
||||
SchedulerCallback(train_config.scheduler),
|
||||
]
|
||||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
return (TrainContextBuilder(self)
|
||||
return (TrainContextBuilder(self.train_config)
|
||||
.with_checkpoint(checkpoint)
|
||||
.with_optimizer()
|
||||
.with_scheduler()
|
||||
.with_dataloader()
|
||||
.with_strategy()
|
||||
.build())
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
for callback in self.callbacks:
|
||||
method = getattr(callback, method_name, None)
|
||||
if method:
|
||||
method(self, context)
|
||||
method(context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||
context = self._build_context(checkpoint)
|
||||
self._call_callbacks('on_train_begin', context)
|
||||
|
||||
try:
|
||||
self.parameter.model.train()
|
||||
context.model.train()
|
||||
# 1.epoch
|
||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks('on_epoch_begin', context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
if context.batch_iter % self.train_config.accumulation_steps == 0:
|
||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks('on_step_begin', context)
|
||||
self.train_config.optimizer.step()
|
||||
self.train_config.optimizer.zero_grad()
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', context)
|
||||
|
||||
# 3. batch
|
||||
self._call_callbacks('on_batch_begin', context)
|
||||
loss = self.train_config.strategy(batch)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
context.batch_iter += 1
|
||||
context.iteration += 1
|
||||
|
||||
# to make the loss normalized by accumulation steps
|
||||
normalized_loss = loss / self.train_config.accumulation_steps
|
||||
normalized_loss.backward()
|
||||
stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs
|
||||
stand_loss = loss / stand_batch
|
||||
stand_loss.backward()
|
||||
|
||||
self._call_callbacks('on_batch_end', context)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,20 @@ from khaosz.trainer import *
|
|||
|
||||
def test_callback_integration(base_test_env, random_dataset):
|
||||
"""Test that all callbacks are properly integrated"""
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
|
||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
strategy='seq',
|
||||
dataset=random_dataset,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=2,
|
||||
|
|
@ -18,36 +28,26 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
random_seed=42
|
||||
)
|
||||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
|
||||
|
||||
# Create custom callbacks to track calls
|
||||
callback_calls = []
|
||||
|
||||
class TrackingCallback(TrainCallback):
|
||||
def on_train_begin(self, trainer, context):
|
||||
def on_train_begin(self, context):
|
||||
callback_calls.append('on_train_begin')
|
||||
|
||||
def on_batch_end(self, trainer, context):
|
||||
def on_batch_end(self, context):
|
||||
callback_calls.append('on_batch_end')
|
||||
|
||||
def on_epoch_end(self, trainer, context):
|
||||
def on_epoch_end(self, context):
|
||||
callback_calls.append('on_epoch_end')
|
||||
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||
model_parameter = ModelParameter(
|
||||
base_test_env["model"],
|
||||
base_test_env["tokenizer"],
|
||||
base_test_env["transformer_config"]
|
||||
)
|
||||
|
||||
|
||||
trainer = Trainer(
|
||||
model_parameter,
|
||||
train_config,
|
||||
schedule_config,
|
||||
callbacks=[TrackingCallback(), ProgressBarCallback()]
|
||||
callbacks=[TrackingCallback()]
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
|
|
|||
|
|
@ -7,35 +7,34 @@ from khaosz.trainer import *
|
|||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||
"""Simulate early stopping behavior"""
|
||||
|
||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||
|
||||
train_config = TrainConfig(
|
||||
strategy="seq",
|
||||
scheduler=scheduler,
|
||||
model=base_test_env["model"],
|
||||
dataset=early_stopping_dataset,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
n_epoch=2,
|
||||
batch_size=2,
|
||||
checkpoint_interval=2,
|
||||
checkpoint_interval=1,
|
||||
accumulation_steps=2,
|
||||
random_seed=np.random.randint(1e4),
|
||||
)
|
||||
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||
model_parameter = ModelParameter(
|
||||
base_test_env["model"],
|
||||
base_test_env["tokenizer"],
|
||||
base_test_env["transformer_config"]
|
||||
)
|
||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
||||
trainer = Trainer(train_config)
|
||||
|
||||
# Should handle early stopping gracefully
|
||||
checkpoint = None
|
||||
try:
|
||||
checkpoint = trainer.train()
|
||||
assert len(checkpoint.loss_list) == 2
|
||||
assert checkpoint.iteration == 2
|
||||
except Exception:
|
||||
# Handle any exceptions
|
||||
pass
|
||||
|
||||
checkpoint = trainer.train(checkpoint)
|
||||
assert len(checkpoint.loss_list) == 10
|
||||
assert checkpoint.iteration == 10
|
||||
|
|
@ -51,13 +51,6 @@ def test_env(request: pytest.FixtureRequest):
|
|||
|
||||
shutil.rmtree(test_dir)
|
||||
|
||||
# parameter loader
|
||||
def test_parameter_loader(test_env):
|
||||
loaded_param = ParameterLoader.load(test_env["test_dir"])
|
||||
assert loaded_param.model is not None
|
||||
assert loaded_param.tokenizer is not None
|
||||
assert loaded_param.config == test_env["transformer_config"]
|
||||
|
||||
def test_model_parameter(test_env):
|
||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
|
||||
|
|
|
|||
|
|
@ -31,10 +31,18 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
|||
accumulation_steps_list = [1, 2, 4]
|
||||
|
||||
for accumulation_steps in accumulation_steps_list:
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||
train_config = TrainConfig(
|
||||
dataset=random_dataset,
|
||||
strategy="seq",
|
||||
model=base_test_env["model"],
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
dataset=random_dataset,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=2,
|
||||
|
|
@ -44,18 +52,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
|||
random_seed=42
|
||||
)
|
||||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
|
||||
model_parameter = ModelParameter(
|
||||
base_test_env["model"],
|
||||
base_test_env["tokenizer"],
|
||||
base_test_env["transformer_config"]
|
||||
)
|
||||
|
||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
|
||||
assert train_config.accumulation_steps == accumulation_steps
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ def test_schedule_factory_random_configs():
|
|||
config.validate()
|
||||
|
||||
# Create scheduler using factory
|
||||
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
|
||||
# Verify scheduler type
|
||||
if isinstance(config, CosineScheduleConfig):
|
||||
|
|
@ -83,7 +83,7 @@ def test_schedule_factory_edge_cases():
|
|||
|
||||
for config in edge_cases:
|
||||
config.validate()
|
||||
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
assert scheduler is not None
|
||||
|
||||
# Test multiple steps
|
||||
|
|
@ -97,16 +97,17 @@ def test_schedule_factory_invalid_configs():
|
|||
# Test invalid configurations that should raise errors
|
||||
invalid_configs = [
|
||||
# Negative warmup steps
|
||||
CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1),
|
||||
{"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
|
||||
# Total steps less than warmup steps
|
||||
CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1),
|
||||
{"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
|
||||
# Invalid min_rate
|
||||
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1),
|
||||
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1),
|
||||
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
|
||||
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
|
||||
]
|
||||
|
||||
for config in invalid_configs:
|
||||
for kwargs in invalid_configs:
|
||||
with pytest.raises(ValueError):
|
||||
config = CosineScheduleConfig(**kwargs)
|
||||
config.validate()
|
||||
|
||||
|
||||
|
|
@ -117,7 +118,7 @@ def test_schedule_factory_state_persistence():
|
|||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||
|
||||
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
|
||||
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
|
||||
# Take a few steps
|
||||
for _ in range(5):
|
||||
|
|
@ -127,7 +128,7 @@ def test_schedule_factory_state_persistence():
|
|||
state_dict = scheduler.state_dict()
|
||||
|
||||
# Create new scheduler and load state
|
||||
new_scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
||||
new_scheduler = SchedulerFactory.load(optimizer, config)
|
||||
new_scheduler.load_state_dict(state_dict)
|
||||
|
||||
# Verify states match
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ import argparse
|
|||
import torch
|
||||
|
||||
from torch.optim import AdamW
|
||||
from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig
|
||||
from khaosz.trainer import Trainer, StrategyFactory
|
||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||
from khaosz.trainer import Trainer, SchedulerFactory
|
||||
from khaosz.data import DatasetLoader
|
||||
|
||||
|
||||
|
|
@ -36,7 +36,6 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||
parser.add_argument("--resume_from_checkpoint", action="store_true", help="Train from checkpoint or not.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -66,16 +65,12 @@ def train(
|
|||
pin_memory: bool,
|
||||
window_size: int,
|
||||
stride: int,
|
||||
resume_from_checkpoint: bool
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
||||
parameter = ParameterLoader.load(param_path)
|
||||
checkpoint = None
|
||||
|
||||
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
||||
checkpoint = parameter
|
||||
parameter = ModelParameter()
|
||||
parameter.load(param_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = parameter.config.m_len
|
||||
|
|
@ -91,13 +86,6 @@ def train(
|
|||
"pad_token_id": parameter.tokenizer.pad_id,
|
||||
}
|
||||
|
||||
strategy = StrategyFactory.load(
|
||||
model,
|
||||
train_type,
|
||||
device,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
dataset = DatasetLoader.load(
|
||||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
|
|
@ -111,16 +99,25 @@ def train(
|
|||
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
|
||||
]
|
||||
|
||||
optim = AdamW(
|
||||
optimizer = AdamW(
|
||||
param_groups,
|
||||
betas=(adamw_beta1, adamw_beta2),
|
||||
weight_decay=adamw_weight_decay
|
||||
)
|
||||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=warmup_steps,
|
||||
total_steps=len(dataset) * n_epoch // batch_size,
|
||||
)
|
||||
|
||||
scheduler = SchedulerFactory.load(optimizer, schedule_config)
|
||||
|
||||
train_config = TrainConfig(
|
||||
strategy=strategy,
|
||||
model=model,
|
||||
strategy=train_type,
|
||||
dataset=dataset,
|
||||
optimizer=optim,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
n_epoch=n_epoch,
|
||||
batch_size=batch_size,
|
||||
|
|
@ -134,17 +131,8 @@ def train(
|
|||
pin_memory=pin_memory
|
||||
)
|
||||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=warmup_steps,
|
||||
total_steps=len(dataset) * n_epoch // batch_size,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
parameter=parameter,
|
||||
train_config=train_config,
|
||||
schedule_config=schedule_config,
|
||||
)
|
||||
trainer.train(checkpoint)
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue