refactor(trainer): optimize trainer structure

ViperEkura committed 2025-12-07 21:23:05 +08:00
parent 82e65ccc21
commit c98b175cd5
18 changed files with 314 additions and 424 deletions

View File

@ -4,7 +4,6 @@ __author__ = "ViperEkura"
from khaosz.api import Khaosz
from khaosz.config import (
ModelConfig,
ParameterLoader,
TrainConfig,
)
from khaosz.model.transformer import Transformer
@ -42,7 +41,6 @@ __all__ = [
"PriorityTextSplitter",
"ModelConfig",
"ParameterLoader",
"TrainConfig",
"DatasetLoader",

View File

@ -9,12 +9,13 @@ from khaosz.inference.generator import (
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.config.param_config import ParameterLoader
from khaosz.config.param_config import ModelParameter
class Khaosz:
def __init__(self, model_dir: str):
self.parameter = ParameterLoader.load(model_dir)
self.parameter = ModelParameter()
self.parameter.load(model_dir)
def to(self, *args, **kwargs):
self.parameter.to(*args, **kwargs)
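With ParameterLoader gone, callers construct a ModelParameter and load it in place. A minimal sketch of the new call site (the directory path is illustrative):

    from khaosz.config import ModelParameter

    parameter = ModelParameter()
    parameter.load("model_dir")   # populates model/tokenizer/config, returns self
    parameter.to("cuda")          # delegates to the wrapped nn.Module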

View File

@ -1,5 +1,5 @@
from khaosz.config.model_config import ModelConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
from khaosz.config.param_config import BaseModelIO, ModelParameter
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
from khaosz.config.train_config import TrainConfig
@ -7,8 +7,6 @@ from khaosz.config.train_config import TrainConfig
__all__ = [
"BaseModelIO",
"ModelParameter",
"Checkpoint",
"ParameterLoader",
"ModelConfig",
"TrainConfig",

View File

@ -1,11 +1,8 @@
import pickle as pkl
import matplotlib.pyplot as plt
import safetensors.torch as st
import torch.nn as nn
import torch.optim as optim
import safetensors.torch as st
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Self, Union
from typing import Optional, Self, Union
from pathlib import Path
from khaosz.data.tokenizer import BpeTokenizer
@ -63,7 +60,7 @@ class BaseModelIO:
return self
def to(self, *args, **kwargs) -> Self:
def to(self, *args, **kwargs) -> "BaseModelIO":
"""Move model to device."""
if self.model is not None:
self.model.to(*args, **kwargs)
@ -77,142 +74,6 @@ class ModelParameter(BaseModelIO):
def save(self, save_dir: Union[str, Path]):
self.save_components(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
return self.load_components(load_dir)
@dataclass
class Checkpoint(BaseModelIO):
"""Extended model parameters with training state."""
optimizer_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Optimizer state."}
)
scheduler_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Sampler state."}
)
loss_list: List[float] = field(
default_factory=list,
metadata={"help": "List of training losses."}
)
epoch: int = field(
default=0,
metadata={"help": "Current epoch."}
)
batch_iter: int = field(
default=0,
metadata={"help": "Current iteration."}
)
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
dir_path = Path(directory)
return {
"loss_plot": dir_path / "loss_plot.png",
"training_state": dir_path / "training_state.pkl"
}
def to_dict(self) -> Dict[str, Any]:
return {
"optimizer_state": self.optimizer_state,
"scheduler_state": self.scheduler_state,
"epoch": self.epoch,
"batch_iter": self.batch_iter,
"loss_list": self.loss_list,
}
def from_dict(self, data: Dict[str, Any]) -> Self:
self.optimizer_state = data["optimizer_state"]
self.scheduler_state = data["scheduler_state"]
self.epoch = data["epoch"]
self.batch_iter = data["batch_iter"]
self.loss_list = data["loss_list"]
def save_training_state(self, save_dir: Union[str, Path]):
paths = self._get_training_paths(save_dir)
# Save loss plot
self._plot_loss(str(paths["loss_plot"]))
# Save training state
with open(str(paths["training_state"]), "wb") as f:
pkl.dump(self.to_dict(), f)
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
paths = self._get_training_paths(load_dir)
# Load training state
with open(str(paths["training_state"]), "rb") as f:
train_state = pkl.load(f)
self.from_dict(train_state)
return self
def _plot_loss(self, save_path: str):
"""Plot and save loss curve."""
if not self.loss_list:
return
batch_iter = len(self.loss_list)
plt.figure(figsize=(10, 6))
plt.plot(self.loss_list)
plt.title(f"Training Loss - Iteration {batch_iter}")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.grid(True)
plt.savefig(save_path, dpi=30, bbox_inches="tight")
plt.close()
def save(self, save_dir: Union[str, Path]):
"""Save complete checkpoint."""
self.save_components(save_dir)
self.save_training_state(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
"""Load complete checkpoint."""
self.load_components(load_dir)
self.load_training_state(load_dir)
return self
class ParameterLoader:
"""Factory class for loading model parameters or checkpoints."""
@staticmethod
def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
"""Load either ModelParameter or Checkpoint based on directory contents."""
load_dir = Path(load_dir)
# Check for training-specific files
loss_file = load_dir / "loss.pkl"
has_training_data = loss_file.exists()
# Create appropriate instance
if has_training_data:
checkpoint = Checkpoint()
checkpoint.load(str(load_dir))
return checkpoint
else:
params = ModelParameter()
params.load(str(load_dir))
return params
@staticmethod
def create_checkpoint(
model: nn.Module,
tokenizer: BpeTokenizer,
config: ModelConfig,
loss_list: Optional[list[float]] = None,
optimizer: Optional[optim.Optimizer] = None,
) -> Checkpoint:
"""Convenience method to create a training checkpoint."""
return Checkpoint(
model=model,
tokenizer=tokenizer,
config=config,
loss_list=loss_list or [],
optimizer_state=optimizer
)

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Dict
from typing import Any, Dict
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@ -39,7 +39,10 @@ class CosineScheduleConfig(ScheduleConfig):
default=None,
metadata={"help": "Total training steps for cosine schedule."}
)
schedule_type: Literal["cosine"] = "cosine"
def __post_init__(self) -> None:
self.schedule_type = "cosine"
self.validate()
def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None:
@ -68,7 +71,10 @@ class SGDRScheduleConfig(ScheduleConfig):
default=2,
metadata={"help": "Multiplier for cycle length growth."}
)
schedule_type: Literal["sgdr"] = "sgdr"
def __post_init__(self) -> None:
self.schedule_type = "sgdr"
self.validate()
def get_kwargs(self) -> Dict[str, Any]:
return {

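Replacing the Literal class attribute with __post_init__ means each schedule config sets its type and validates itself at construction. A short sketch, assuming the field defaults shown above:

    config = CosineScheduleConfig(warmup_steps=100, total_steps=1000)
    assert config.schedule_type == "cosine"   # assigned in __post_init__
    # invalid values now fail at construction rather than at first use:
    # CosineScheduleConfig(warmup_steps=-10, total_steps=1000)  raises ValueError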
View File

@ -1,15 +1,20 @@
from dataclasses import dataclass, field
from typing import Optional, TYPE_CHECKING
from torch import nn
from torch.utils.data import Dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
if TYPE_CHECKING:
from khaosz.trainer.strategy import BaseStrategy
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TrainConfig:
strategy: "BaseStrategy" = field(
# basic setting
model: nn.Module = field(
default=None,
metadata={"help": "Model for training."}
)
strategy: str = field(
default=None,
metadata={"help": "Training strategy."}
)
@ -21,9 +26,9 @@ class TrainConfig:
default=None,
metadata={"help": "Optimizer for training."}
)
checkpoint_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
scheduler: LRScheduler = field(
default=None,
metadata={"help": "Scheduler for training."}
)
n_epoch: int = field(
default=1,
@ -33,6 +38,16 @@ class TrainConfig:
default=4,
metadata={"help": "Batch size for training."}
)
accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
)
# checkpoint setting
start_epoch: int = field(
default=0,
metadata={"help": "Start epoch for training."}
@ -41,18 +56,14 @@ class TrainConfig:
default=0,
metadata={"help": "Start batch iteration for training."}
)
checkpoint_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
)
checkpoint_interval: int = field(
default=5000,
metadata={"help": "Number of iterations between checkpoints."}
)
accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
)
# dataloader setting
random_seed: int = field(
@ -77,3 +88,9 @@ class TrainConfig:
default=1,
metadata={"help": "Number of processes for distributed training."}
)
# others
kwargs: dict = field(
default_factory=dict,
metadata={"help": "Other arguments."}
)
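TrainConfig now carries the model, a strategy name, and a prebuilt scheduler, with kwargs forwarded to the strategy. A construction sketch (all values illustrative):

    train_config = TrainConfig(
        model=model,                  # nn.Module to train
        strategy="seq",               # resolved later by StrategyFactory
        dataset=dataset,
        optimizer=optimizer,
        scheduler=scheduler,
        accumulation_steps=2,
        kwargs={"pad_token_id": 0},   # passed through to the strategy
    )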

View File

@ -0,0 +1,65 @@
import os
import pickle as pkl
import matplotlib.pyplot as plt
from typing import Any, Dict, Optional
class Checkpoint:
def __init__(
self,
optimizer_state: Dict[str, Any],
scheduler_state: Dict[str, Any],
epoch: int = 0,
iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
):
self.optimizer_state = optimizer_state
self.scheduler_state = scheduler_state
self.epoch, self.iteration = epoch, iteration
self.metrics = metrics
def save(self, save_dir: str, save_metric_plot=True) -> None:
os.makedirs(save_dir, exist_ok=True)
train_state = {
"epoch": self.epoch,
"iteration": self.iteration,
"metrics": self.metrics,
"optimizer_state": self.optimizer_state,
"scheduler_state": self.scheduler_state,
}
with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
pkl.dump(train_state, f)
if save_metric_plot and self.metrics:
self._plot_metrics(save_dir)
def load(self, load_dir: str) -> "Checkpoint":
if not os.path.exists(load_dir):
raise FileNotFoundError(f"Checkpoint directory {load_dir} does not exist.")
with open(os.path.join(load_dir, "train_state.pkl"), "rb") as f:
train_state = pkl.load(f)
self.epoch = train_state["epoch"]
self.iteration = train_state["iteration"]
self.metrics = train_state["metrics"]
self.optimizer_state = train_state["optimizer_state"]
self.scheduler_state = train_state["scheduler_state"]
return self
def _plot_metrics(self, save_dir: str):
for metric_name, metric_value in self.metrics.items():
plt.figure(figsize=(10, 6))
plt.plot(metric_value, label=metric_name)
plt.xlabel('Step')
plt.ylabel('Value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(os.path.join(save_dir, f'{metric_name}.png'), dpi=150, bbox_inches='tight')
plt.close()
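A save/load round trip for the trainer-level Checkpoint, as a sketch (the directory name is illustrative):

    ckpt = Checkpoint(optimizer.state_dict(), scheduler.state_dict(), epoch=1, iteration=500)
    ckpt.save("checkpoint/epoch_1_iter_500")              # writes train_state.pkl plus metric plots
    restored = Checkpoint(None, None).load("checkpoint/epoch_1_iter_500")
    optimizer.load_state_dict(restored.optimizer_state)   # resume where training left off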

View File

@ -151,7 +151,7 @@ class SchedulerFactory:
"""
@staticmethod
def load_scheduler(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
kwargs = schedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, Tuple, Callable, Dict, Union
from typing import Any, Callable, Dict, Union
from abc import ABC, abstractmethod
@ -41,7 +41,7 @@ class BaseStrategy(ABC):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
raise NotImplementedError
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
return self.compute_loss(batch)
@ -94,7 +94,7 @@ class DpoStrategy(BaseStrategy):
self.pad_token_id = pad_token_id
self.beta = beta
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
good_ids, bad_ids = batch["chosen"], batch["rejected"]
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
@ -115,41 +115,6 @@ class DpoStrategy(BaseStrategy):
return dpo_loss
class PpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, epsilon):
super().__init__(model)
ref_model = copy.deepcopy(self.model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.pad_token_id = pad_token_id
self.epsilon = epsilon
def ppo_clip_loss_masked(
self,
log_probs: Tensor,
old_log_probs: Tensor,
advantages: Tensor,
values: Tensor,
returns: Tensor,
mask: Tensor,
clip_eps: float=0.2,
):
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
value_loss = F.mse_loss(values.masked_select(mask),
returns.masked_select(mask))
entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
entropy_loss = -entropy
return policy_loss, value_loss, entropy_loss
class StrategyFactory:
def load(model, train_type, device, **kwargs):
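With PpoStrategy removed, the remaining strategies are obtained through the factory and invoked on dict batches. A sketch (keyword values illustrative):

    strategy = StrategyFactory.load(model, "dpo", device, pad_token_id=0, beta=0.1)
    loss = strategy(batch)   # batch: Dict[str, Tensor], moved to device inside compute_loss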

View File

@ -5,10 +5,9 @@ import time
from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import LRScheduler
from typing import List, Optional, Protocol, TYPE_CHECKING
from khaosz.config import ScheduleConfig
from khaosz.trainer.metric_util import (
grad_max,
grad_min,
@ -17,9 +16,9 @@ from khaosz.trainer.metric_util import (
grad_std,
grad_nan_num
)
from khaosz.trainer.checkpoint import Checkpoint
if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.train_context import TrainContext
@ -28,31 +27,31 @@ class TrainCallback(Protocol):
Callback interface for trainer.
"""
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_train_begin(self, context: 'TrainContext'):
""" Called at the beginning of training. """
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
def on_train_end(self, context: 'TrainContext'):
""" Called at the end of training. """
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_epoch_begin(self, context: 'TrainContext'):
""" Called at the beginning of each epoch. """
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
def on_epoch_end(self, context: 'TrainContext'):
""" Called at the end of each epoch. """
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_step_begin(self, context: 'TrainContext'):
""" Called at the beginning of each step. """
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
def on_step_end(self, context: 'TrainContext'):
""" Called at the end of each step."""
def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_batch_begin(self, context: 'TrainContext'):
""" Called at the beginning of each batch. """
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
def on_batch_end(self, context: 'TrainContext'):
""" Called at the end of each batch. """
def on_error(self, trainer: 'Trainer', context: 'TrainContext'):
def on_error(self, context: 'TrainContext'):
""" Called when an error occurs during training. """
@ -63,29 +62,27 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_step_begin(self, context: 'TrainContext'):
_ = context
clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm)
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
class SchedulerCallback(TrainCallback):
"""
Scheduler callback for trainer.
"""
def __init__(self, schedule_config: ScheduleConfig):
self.schedule_config = schedule_config
self.scheduler: Optional[LambdaLR] = None
def __init__(self, scheduler: LRScheduler):
self.scheduler: LRScheduler = scheduler
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
for group in trainer.train_config.optimizer.param_groups:
def on_train_begin(self, context: 'TrainContext'):
for group in context.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
self.scheduler = context.scheduler
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
_ = trainer, context
def on_batch_end(self, context: 'TrainContext'):
_ = context
if self.scheduler:
self.scheduler.step()
@ -94,54 +91,59 @@ class CheckpointCallback(TrainCallback):
"""
Checkpoint callback for trainer.
"""
def __init__(self, checkpoint_interval: int):
self.checkpoint_interval = checkpoint_interval
def __init__(self, interval: int, save_dir: str):
self.interval = interval
self.save_dir = save_dir
self.checkpoint = None
self.last_ckpt_iter = 0
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}")
context.checkpoint.optimizer_state = context.optimizer.state_dict()
context.checkpoint.scheduler_state = context.scheduler.state_dict()
context.checkpoint.epoch = context.epoch
context.checkpoint.batch_iter = context.batch_iter
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.batch_iter
def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
self.checkpoint = Checkpoint(
context.optimizer.state_dict(),
context.scheduler.state_dict(),
context.epoch,
context.iteration
)
self.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
context.checkpoint.loss_list.append(context.loss)
def on_batch_end(self, context: 'TrainContext'):
if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval:
self._save_checkpoint(trainer, context)
def on_train_end(self, context: 'TrainContext'):
if context.iteration != self.last_ckpt_iter:
self._save_checkpoint(context)
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
if context.batch_iter != self.last_ckpt_iter:
self._save_checkpoint(trainer, context)
def on_error(self, context: 'TrainContext'):
self._save_checkpoint(context)
class ProgressBarCallback(TrainCallback):
"""
Progress bar callback for trainer.
"""
def __init__(self):
def __init__(self, num_epoch: int):
self.num_epoch = num_epoch
self.progress_bar: tqdm = None
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_epoch_begin(self, context: 'TrainContext'):
self.progress_bar = tqdm(
context.dataloader,
desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}",
desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
dynamic_ncols=True
)
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
_ = trainer
def on_batch_end(self, context: 'TrainContext'):
self.progress_bar.set_postfix({
"loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
})
self.progress_bar.update(1)
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
_ = trainer, context
def on_epoch_end(self, context: 'TrainContext'):
_ = context
if self.progress_bar:
self.progress_bar.close()
@ -177,13 +179,13 @@ class StepMonitorCallback(TrainCallback):
self.log_dir.mkdir(parents=True, exist_ok=True)
def _handle_info(self, trainer: 'Trainer', context: 'TrainContext'):
def _handle_info(self, context: 'TrainContext'):
""" Logs training information to console and file. """
log_data = {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
"epoch": context.epoch,
"iter": context.batch_iter,
"iter": context.iteration,
"metrics": self.metrics,
}
@ -193,34 +195,34 @@ class StepMonitorCallback(TrainCallback):
elif metric == 'lr':
log_data[metric] = context.optimizer.param_groups[-1]['lr']
elif metric == 'grad_norm':
log_data[metric] = grad_norm(trainer.parameter.model)
log_data[metric] = grad_norm(context.model)
elif metric == 'grad_std':
log_data[metric] = grad_std(trainer.parameter.model)
log_data[metric] = grad_std(context.model)
elif metric == 'grad_max':
log_data[metric] = grad_max(trainer.parameter.model)
log_data[metric] = grad_max(context.model)
elif metric == 'grad_min':
log_data[metric] = grad_min(trainer.parameter.model)
log_data[metric] = grad_min(context.model)
elif metric == 'grad_mean':
log_data[metric] = grad_mean(trainer.parameter.model)
log_data[metric] = grad_mean(context.model)
elif metric == 'grad_nan_num':
log_data[metric] = grad_nan_num(trainer.parameter.model)
log_data[metric] = grad_nan_num(context.model)
else:
raise ValueError(f"Invalid metric: {metric}")
return log_data
def _handle_log(self, trainer: 'Trainer', context: 'TrainContext'):
def _handle_log(self, context: 'TrainContext'):
""" Logs training information to console and file. """
log_data = self._handle_info(trainer, context)
log_data = self._handle_info(context)
try:
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json"
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.iteration}.json"
with open(log_file, 'a') as f:
json.dump(log_data, f, indent=4)
except Exception:
raise
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
def on_step_end(self, context: 'TrainContext'):
if self.step_num % self.log_interval == 0:
self._handle_log(trainer, context)
self._handle_log(context)
self.step_num += 1
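Callbacks are handed straight to Trainer; omitting them falls back to the defaults built from TrainConfig, e.g.:

    trainer = Trainer(train_config, callbacks=[
        ProgressBarCallback(train_config.n_epoch),
        CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
        GradientClippingCallback(train_config.max_grad_norm),
        SchedulerCallback(train_config.scheduler),
    ])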

View File

@ -1,61 +1,60 @@
from dataclasses import dataclass, field, fields
from typing import Optional, Self, TYPE_CHECKING
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from khaosz.config import Checkpoint
from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
from khaosz.parallel.utils import get_world_size, get_rank
if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer
from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.config.train_config import TrainConfig
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
from dataclasses import dataclass, field
from typing import Optional, Self
@dataclass
class TrainContext:
model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None)
optimizer: Optimizer = field(default=None)
scheduler: BaseScheduler = field(default=None)
scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None)
epoch: int = field(default=0)
batch_iter: int = field(default=0)
iteration: int = field(default=0)
loss: float = field(default=0.0)
world_size: int = field(default=1)
rank: int = field(default=0)
def asdict(self) -> dict:
return {field.name: getattr(self, field.name)
for field in fields(self)}
class TrainContextBuilder:
def __init__(self, trainer: 'Trainer'):
self.trainer = trainer
def __init__(self, config: TrainConfig):
self.config = config
self._context: TrainContext = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._context = TrainContext()
if checkpoint is None:
checkpoint = Checkpoint(
model=self.trainer.parameter.model,
tokenizer=self.trainer.parameter.tokenizer,
config=self.trainer.parameter.config,
optimizer_state=self.config.optimizer.state_dict(),
scheduler_state=self.config.scheduler.state_dict(),
)
else:
# resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch)
self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch)
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.checkpoint = checkpoint
return self
def with_optimizer(self) -> Self:
optimizer = self.config.optimizer
if self._context is None:
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
optimizer = self.trainer.train_config.optimizer
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
@ -67,13 +66,7 @@ class TrainContextBuilder:
return self
def with_scheduler(self) -> Self:
if not hasattr(self._context, 'optimizer') or self._context.optimizer is None:
raise RuntimeError("Must call with_optimizer() before with_scheduler()")
optimizer = self.trainer.train_config.optimizer
schedule_config = self.trainer.schedule_config
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
scheduler = self.config.scheduler
if self._context.checkpoint and self._context.checkpoint.scheduler_state:
scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
@ -85,29 +78,41 @@ class TrainContextBuilder:
return self
def with_dataloader(self) -> Self:
# fix: change batch level batch_iter to sample level offset
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
# fix: change batch level iteration to sample level offset
config = self.config
sampler_offset = self._context.iteration * config.batch_size
resumable_sampler = ResumableDistributedSampler(
data_source=self.trainer.train_config.dataset,
data_source=config.dataset,
start_epoch=self._context.epoch,
start_iter=sampler_offset,
seed=self.trainer.train_config.random_seed
seed=config.random_seed
)
dataloader = DataLoader(
self.trainer.train_config.dataset,
batch_size=self.trainer.train_config.batch_size,
config.dataset,
batch_size=config.batch_size,
sampler=resumable_sampler,
num_workers=self.trainer.train_config.num_workers,
pin_memory=self.trainer.train_config.pin_memory,
prefetch_factor=self.trainer.train_config.prefetch_factor
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor
)
self._context.dataloader = dataloader
return self
def build(self) -> TrainContext:
def with_strategy(self) -> Self:
device = get_current_device()
self._context.strategy = StrategyFactory.load(
model=self.config.model,
train_type=self.config.strategy,
device=device,
**self.config.kwargs
)
return self
if self.trainer.train_config.nprocs > 1:
def build(self) -> TrainContext:
self._context.model = self.config.model
if self.config.nprocs > 1:
self._context.world_size = get_world_size()
self._context.rank = get_rank()
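The builder now takes the config rather than the trainer; the full chain, mirroring Trainer._build_context below:

    context = (TrainContextBuilder(train_config)
               .with_checkpoint(None)   # or a Checkpoint to resume from
               .with_optimizer()
               .with_scheduler()
               .with_dataloader()
               .with_strategy()
               .build())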

View File

@ -1,11 +1,6 @@
import logging
from typing import Optional, List
from khaosz.config import (
ModelParameter,
Checkpoint,
ScheduleConfig,
TrainConfig
)
from khaosz.config import TrainConfig
from khaosz.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
@ -13,7 +8,7 @@ from khaosz.trainer.train_callback import (
GradientClippingCallback,
SchedulerCallback
)
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
logger = logging.getLogger(__name__)
@ -21,66 +16,65 @@ logger = logging.getLogger(__name__)
class Trainer:
def __init__(
self,
parameter: ModelParameter,
train_config: TrainConfig,
schedule_config: ScheduleConfig,
callbacks: Optional[List[TrainCallback]] = None
):
self.parameter = parameter
self.train_config = train_config
self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks()
def _get_default_callbacks(self) -> List[TrainCallback]:
train_config = self.train_config
return [
ProgressBarCallback(),
CheckpointCallback(self.train_config.checkpoint_interval),
GradientClippingCallback(self.train_config.max_grad_norm),
SchedulerCallback(self.schedule_config),
ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(train_config.scheduler),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self)
return (TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_optimizer()
.with_scheduler()
.with_dataloader()
.with_strategy()
.build())
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
if method:
method(self, context)
method(context)
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
self._call_callbacks('on_train_begin', context)
try:
self.parameter.model.train()
context.model.train()
# 1. epoch
for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch
self._call_callbacks('on_epoch_begin', context)
for batch in context.dataloader:
if context.batch_iter % self.train_config.accumulation_steps == 0:
if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step
self._call_callbacks('on_step_begin', context)
self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad()
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks('on_step_end', context)
# 3. batch
self._call_callbacks('on_batch_begin', context)
loss = self.train_config.strategy(batch)
loss = context.strategy(batch)
context.loss = loss.item()
context.batch_iter += 1
context.iteration += 1
# to make the loss normalized by accumulation steps
normalized_loss = loss / self.train_config.accumulation_steps
normalized_loss.backward()
stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs
stand_loss = loss / stand_batch
stand_loss.backward()
self._call_callbacks('on_batch_end', context)
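The backward pass now divides by accumulation_steps * nprocs so accumulated micro-batches sum to one full-batch-equivalent gradient. For example:

    # batch_size=4, accumulation_steps=2, nprocs=2
    # -> each optimizer step aggregates 4 * 2 * 2 = 16 samples, and every
    #    micro-batch loss is scaled by 1 / (2 * 2) before backward(), matching
    #    the gradient magnitude of a single 16-sample step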

View File

@ -5,10 +5,20 @@ from khaosz.trainer import *
def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated"""
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig(
model=base_test_env["model"],
strategy='seq',
dataset=random_dataset,
optimizer=optimizer,
scheduler=scheduler,
checkpoint_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=2,
@ -18,36 +28,26 @@ def test_callback_integration(base_test_env, random_dataset):
random_seed=42
)
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
# Create custom callbacks to track calls
callback_calls = []
class TrackingCallback(TrainCallback):
def on_train_begin(self, trainer, context):
def on_train_begin(self, context):
callback_calls.append('on_train_begin')
def on_batch_end(self, trainer, context):
def on_batch_end(self, context):
callback_calls.append('on_batch_end')
def on_epoch_end(self, trainer, context):
def on_epoch_end(self, context):
callback_calls.append('on_epoch_end')
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],
base_test_env["transformer_config"]
)
trainer = Trainer(
model_parameter,
train_config,
schedule_config,
callbacks=[TrackingCallback(), ProgressBarCallback()]
callbacks=[TrackingCallback()]
)
trainer.train()

View File

@ -7,35 +7,34 @@ from khaosz.trainer import *
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig(
strategy="seq",
scheduler=scheduler,
model=base_test_env["model"],
dataset=early_stopping_dataset,
optimizer=optimizer,
checkpoint_dir=base_test_env["test_dir"],
n_epoch=2,
batch_size=2,
checkpoint_interval=2,
checkpoint_interval=1,
accumulation_steps=2,
random_seed=np.random.randint(1e4),
)
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],
base_test_env["transformer_config"]
)
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
trainer = Trainer(model_parameter, train_config, schedule_config)
trainer = Trainer(train_config)
# Should handle early stopping gracefully
checkpoint = None
try:
checkpoint = trainer.train()
assert len(checkpoint.loss_list) == 2
assert checkpoint.iteration == 2
except Exception:
# Handle any exceptions
pass
checkpoint = trainer.train(checkpoint)
assert len(checkpoint.loss_list) == 10
assert checkpoint.iteration == 10

View File

@ -51,13 +51,6 @@ def test_env(request: pytest.FixtureRequest):
shutil.rmtree(test_dir)
# parameter loader
def test_parameter_loader(test_env):
loaded_param = ParameterLoader.load(test_env["test_dir"])
assert loaded_param.model is not None
assert loaded_param.tokenizer is not None
assert loaded_param.config == test_env["transformer_config"]
def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])

View File

@ -31,10 +31,18 @@ def test_gradient_accumulation(base_test_env, random_dataset):
accumulation_steps_list = [1, 2, 4]
for accumulation_steps in accumulation_steps_list:
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig(
dataset=random_dataset,
strategy="seq",
model=base_test_env["model"],
optimizer=optimizer,
scheduler=scheduler,
dataset=random_dataset,
checkpoint_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=2,
@ -44,18 +52,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
random_seed=42
)
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],
base_test_env["transformer_config"]
)
trainer = Trainer(model_parameter, train_config, schedule_config)
trainer = Trainer(train_config)
trainer.train()
assert train_config.accumulation_steps == accumulation_steps

View File

@ -35,7 +35,7 @@ def test_schedule_factory_random_configs():
config.validate()
# Create scheduler using factory
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
scheduler = SchedulerFactory.load(optimizer, config)
# Verify scheduler type
if isinstance(config, CosineScheduleConfig):
@ -83,7 +83,7 @@ def test_schedule_factory_edge_cases():
for config in edge_cases:
config.validate()
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
scheduler = SchedulerFactory.load(optimizer, config)
assert scheduler is not None
# Test multiple steps
@ -97,16 +97,17 @@ def test_schedule_factory_invalid_configs():
# Test invalid configurations that should raise errors
invalid_configs = [
# Negative warmup steps
CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1),
{"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
# Total steps less than warmup steps
CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1),
{"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
# Invalid min_rate
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1),
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1),
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
]
for config in invalid_configs:
for kwargs in invalid_configs:
with pytest.raises(ValueError):
config = CosineScheduleConfig(**kwargs)
config.validate()
@ -117,7 +118,7 @@ def test_schedule_factory_state_persistence():
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
scheduler = SchedulerFactory.load(optimizer, config)
# Take a few steps
for _ in range(5):
@ -127,7 +128,7 @@ def test_schedule_factory_state_persistence():
state_dict = scheduler.state_dict()
# Create new scheduler and load state
new_scheduler = SchedulerFactory.load_scheduler(optimizer, config)
new_scheduler = SchedulerFactory.load(optimizer, config)
new_scheduler.load_state_dict(state_dict)
# Verify states match

View File

@ -3,8 +3,8 @@ import argparse
import torch
from torch.optim import AdamW
from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, StrategyFactory
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, SchedulerFactory
from khaosz.data import DatasetLoader
@ -36,7 +36,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--resume_from_checkpoint", action="store_true", help="Train from checkpoint or not.")
args = parser.parse_args()
@ -66,16 +65,12 @@ def train(
pin_memory: bool,
window_size: int,
stride: int,
resume_from_checkpoint: bool
):
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
parameter = ParameterLoader.load(param_path)
checkpoint = None
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
checkpoint = parameter
parameter = ModelParameter()
parameter.load(param_path)
if window_size is None:
window_size = parameter.config.m_len
@ -91,13 +86,6 @@ def train(
"pad_token_id": parameter.tokenizer.pad_id,
}
strategy = StrategyFactory.load(
model,
train_type,
device,
**kwargs
)
dataset = DatasetLoader.load(
train_type=train_type,
load_path=data_root_path,
@ -111,16 +99,25 @@ def train(
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
]
optim = AdamW(
optimizer = AdamW(
param_groups,
betas=(adamw_beta1, adamw_beta2),
weight_decay=adamw_weight_decay
)
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // batch_size,
)
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig(
strategy=strategy,
model=model,
strategy=train_type,
dataset=dataset,
optimizer=optim,
optimizer=optimizer,
scheduler=scheduler,
checkpoint_dir=checkpoint_dir,
n_epoch=n_epoch,
batch_size=batch_size,
@ -134,17 +131,8 @@ def train(
pin_memory=pin_memory
)
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // batch_size,
)
trainer = Trainer(
parameter=parameter,
train_config=train_config,
schedule_config=schedule_config,
)
trainer.train(checkpoint)
trainer = Trainer(train_config)
trainer.train()
if __name__ == "__main__":
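Putting the pieces together, the refactored entry point reduces to the following wiring (hyperparameter values illustrative):

    parameter = ModelParameter()
    parameter.load(param_path)
    optimizer = AdamW(parameter.model.parameters(), lr=1e-4)
    scheduler = SchedulerFactory.load(
        optimizer,
        CosineScheduleConfig(warmup_steps=1000, total_steps=10000)
    )
    train_config = TrainConfig(
        model=parameter.model,
        strategy="seq",
        dataset=dataset,
        optimizer=optimizer,
        scheduler=scheduler,
        kwargs={"pad_token_id": parameter.tokenizer.pad_id},
    )
    Trainer(train_config).train()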