refactor(trainer): 优化 trainer 结构

This commit is contained in:
ViperEkura 2025-12-07 21:23:05 +08:00
parent 82e65ccc21
commit c98b175cd5
18 changed files with 314 additions and 424 deletions

View File

@ -4,7 +4,6 @@ __author__ = "ViperEkura"
from khaosz.api import Khaosz from khaosz.api import Khaosz
from khaosz.config import ( from khaosz.config import (
ModelConfig, ModelConfig,
ParameterLoader,
TrainConfig, TrainConfig,
) )
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
@ -42,7 +41,6 @@ __all__ = [
"PriorityTextSplitter", "PriorityTextSplitter",
"ModelConfig", "ModelConfig",
"ParameterLoader",
"TrainConfig", "TrainConfig",
"DatasetLoader", "DatasetLoader",

View File

@ -9,12 +9,13 @@ from khaosz.inference.generator import (
RetrievalGenerator, RetrievalGenerator,
EmbeddingEncoder EmbeddingEncoder
) )
from khaosz.config.param_config import ParameterLoader from khaosz.config.param_config import ModelParameter
class Khaosz: class Khaosz:
def __init__(self, model_dir: str): def __init__(self, model_dir: str):
self.parameter = ParameterLoader.load(model_dir) self.parameter = ModelParameter()
self.parameter.load(model_dir)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
self.parameter.to(*args, **kwargs) self.parameter.to(*args, **kwargs)

View File

@ -1,5 +1,5 @@
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader from khaosz.config.param_config import BaseModelIO, ModelParameter
from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
@ -7,8 +7,6 @@ from khaosz.config.train_config import TrainConfig
__all__ = [ __all__ = [
"BaseModelIO", "BaseModelIO",
"ModelParameter", "ModelParameter",
"Checkpoint",
"ParameterLoader",
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",

View File

@ -1,11 +1,8 @@
import pickle as pkl
import matplotlib.pyplot as plt
import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import safetensors.torch as st
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Self, Union from typing import Optional, Self, Union
from pathlib import Path from pathlib import Path
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
@ -63,7 +60,7 @@ class BaseModelIO:
return self return self
def to(self, *args, **kwargs) -> Self: def to(self, *args, **kwargs) -> "BaseModelIO":
"""Move model to device.""" """Move model to device."""
if self.model is not None: if self.model is not None:
self.model.to(*args, **kwargs) self.model.to(*args, **kwargs)
@ -77,142 +74,6 @@ class ModelParameter(BaseModelIO):
def save(self, save_dir: Union[str, Path]): def save(self, save_dir: Union[str, Path]):
self.save_components(save_dir) self.save_components(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self: def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
return self.load_components(load_dir) return self.load_components(load_dir)
@dataclass
class Checkpoint(BaseModelIO):
"""Extended model parameters with training state."""
optimizer_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Optimizer state."}
)
scheduler_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Sampler state."}
)
loss_list: List[float] = field(
default_factory=list,
metadata={"help": "List of training losses."}
)
epoch: int = field(
default=0,
metadata={"help": "Current epoch."}
)
batch_iter: int = field(
default=0,
metadata={"help": "Current iteration."}
)
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
dir_path = Path(directory)
return {
"loss_plot": dir_path / "loss_plot.png",
"training_state": dir_path / "training_state.pkl"
}
def to_dict(self) -> Dict[str, Any]:
return {
"optimizer_state": self.optimizer_state,
"scheduler_state": self.scheduler_state,
"epoch": self.epoch,
"batch_iter": self.batch_iter,
"loss_list": self.loss_list,
}
def from_dict(self, data: Dict[str, Any]) -> Self:
self.optimizer_state = data["optimizer_state"]
self.scheduler_state = data["scheduler_state"]
self.epoch = data["epoch"]
self.batch_iter = data["batch_iter"]
self.loss_list = data["loss_list"]
def save_training_state(self, save_dir: Union[str, Path]):
paths = self._get_training_paths(save_dir)
# Save loss plot
self._plot_loss(str(paths["loss_plot"]))
# Save training state
with open(str(paths["training_state"]), "wb") as f:
pkl.dump(self.to_dict(), f)
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
paths = self._get_training_paths(load_dir)
# Load training state
with open(str(paths["training_state"]), "rb") as f:
train_state = pkl.load(f)
self.from_dict(train_state)
return self
def _plot_loss(self, save_path: str):
"""Plot and save loss curve."""
if not self.loss_list:
return
batch_iter = len(self.loss_list)
plt.figure(figsize=(10, 6))
plt.plot(self.loss_list)
plt.title(f"Training Loss - Iteration {batch_iter}")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.grid(True)
plt.savefig(save_path, dpi=30, bbox_inches="tight")
plt.close()
def save(self, save_dir: Union[str, Path]):
"""Save complete checkpoint."""
self.save_components(save_dir)
self.save_training_state(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
"""Load complete checkpoint."""
self.load_components(load_dir)
self.load_training_state(load_dir)
return self
class ParameterLoader:
"""Factory class for loading model parameters or checkpoints."""
@staticmethod
def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
"""Load either ModelParameter or Checkpoint based on directory contents."""
load_dir = Path(load_dir)
# Check for training-specific files
loss_file = load_dir / "loss.pkl"
has_training_data = loss_file.exists()
# Create appropriate instance
if has_training_data:
checkpoint = Checkpoint()
checkpoint.load(str(load_dir))
return checkpoint
else:
params = ModelParameter()
params.load(str(load_dir))
return params
@staticmethod
def create_checkpoint(
model: nn.Module,
tokenizer: BpeTokenizer,
config: ModelConfig,
loss_list: Optional[list[float]] = None,
optimizer: Optional[optim.Optimizer] = None,
) -> Checkpoint:
"""Convenience method to create a training checkpoint."""
return Checkpoint(
model=model,
tokenizer=tokenizer,
config=config,
loss_list=loss_list or [],
optimizer_state=optimizer
)

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Dict from typing import Any, Dict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -39,7 +39,10 @@ class CosineScheduleConfig(ScheduleConfig):
default=None, default=None,
metadata={"help": "Total training steps for cosine schedule."} metadata={"help": "Total training steps for cosine schedule."}
) )
schedule_type: Literal["cosine"] = "cosine"
def __post_init__(self) -> None:
self.schedule_type = "cosine"
self.validate()
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None: if self.total_steps is None:
@ -68,7 +71,10 @@ class SGDRScheduleConfig(ScheduleConfig):
default=2, default=2,
metadata={"help": "Multiplier for cycle length growth."} metadata={"help": "Multiplier for cycle length growth."}
) )
schedule_type: Literal["sgdr"] = "sgdr"
def __post_init__(self) -> None:
self.schedule_type = "sgdr"
self.validate()
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
return { return {

View File

@ -1,15 +1,20 @@
from dataclasses import dataclass, field from torch import nn
from typing import Optional, TYPE_CHECKING
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
if TYPE_CHECKING: from dataclasses import dataclass, field
from khaosz.trainer.strategy import BaseStrategy from typing import Optional
@dataclass @dataclass
class TrainConfig: class TrainConfig:
strategy: "BaseStrategy" = field( # basic setting
model: nn.Module = field(
default=None,
metadata={"help": "Model for training."}
)
strategy: str = field(
default=None, default=None,
metadata={"help": "Training strategy."} metadata={"help": "Training strategy."}
) )
@ -21,9 +26,9 @@ class TrainConfig:
default=None, default=None,
metadata={"help": "Optimizer for training."} metadata={"help": "Optimizer for training."}
) )
checkpoint_dir: str = field( scheduler: LRScheduler = field(
default="./checkpoint", default=None,
metadata={"help": "Checkpoint directory."} metadata={"help": "Scheduler for training."}
) )
n_epoch: int = field( n_epoch: int = field(
default=1, default=1,
@ -33,6 +38,16 @@ class TrainConfig:
default=4, default=4,
metadata={"help": "Batch size for training."} metadata={"help": "Batch size for training."}
) )
accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
)
# checkpoint setting
start_epoch: int = field( start_epoch: int = field(
default=0, default=0,
metadata={"help": "Start epoch for training."} metadata={"help": "Start epoch for training."}
@ -41,18 +56,14 @@ class TrainConfig:
default=0, default=0,
metadata={"help": "Start batch iteration for training."} metadata={"help": "Start batch iteration for training."}
) )
checkpoint_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
)
checkpoint_interval: int = field( checkpoint_interval: int = field(
default=5000, default=5000,
metadata={"help": "Number of iterations between checkpoints."} metadata={"help": "Number of iterations between checkpoints."}
) )
accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
)
# dataloader setting # dataloader setting
random_seed: int = field( random_seed: int = field(
@ -77,3 +88,9 @@ class TrainConfig:
default=1, default=1,
metadata={"help": "Number of processes for distributed training."} metadata={"help": "Number of processes for distributed training."}
) )
# others
kwargs: dict = field(
default_factory=dict,
metadata={"help": "Other arguments."}
)

View File

@ -0,0 +1,65 @@
import os
import pickle as pkl
import matplotlib.pyplot as plt
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from typing import Dict, Optional
class Checkpoint:
def __init__(
self,
optimizer_state: Optimizer,
scheduler_state: LRScheduler,
epoch: int = 0,
iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
):
self.optimizer_state = optimizer_state
self.scheduler_state = scheduler_state
self.epoch, self.iteration = epoch, iteration
self.metrics = metrics
def save(self, save_dir: str, save_metric_plot=True) -> None:
os.makedirs(save_dir, exist_ok=True)
train_state = {
"epoch": self.epoch,
"iteration": self.iteration,
"metrics": self.metrics,
"optimizer_state": self.optimizer_state,
"scheduler_state": self.scheduler_state,
}
with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
pkl.dump(train_state, f)
if save_metric_plot and self.metrics:
self._plot_metrics()
def load(self, save_dir: str) -> "Checkpoint":
if not os.path.exists(save_dir):
raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
train_state = pkl.load(f)
self.epoch = train_state["epoch"]
self.iteration = train_state["iteration"]
self.metrics = train_state["metrics"]
self.optimizer_state = train_state["optimizer_state"]
self.scheduler_state = train_state["scheduler_state"]
return self
def _plot_metrics(self):
for metric_name, metric_value in self.metrics.items():
plt.figure(figsize=(10, 6))
plt.plot(metric_value, label=metric_name)
plt.xlabel('Step')
plt.ylabel('Value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight')
plt.close()

View File

@ -151,7 +151,7 @@ class SchedulerFactory:
""" """
@staticmethod @staticmethod
def load_scheduler(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler: def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
kwargs = scedule_config.get_kwargs() kwargs = scedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type") schedule_type = kwargs.pop("schedule_type")

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from typing import Any, Tuple, Callable, Dict, Union from typing import Any, Callable, Dict, Union
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -41,7 +41,7 @@ class BaseStrategy(ABC):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
raise NotImplementedError raise NotImplementedError
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor: def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
return self.compute_loss(batch) return self.compute_loss(batch)
@ -94,7 +94,7 @@ class DpoStrategy(BaseStrategy):
self.pad_token_id = pad_token_id self.pad_token_id = pad_token_id
self.beta = beta self.beta = beta
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
good_ids, bad_ids = batch["chosen"], batch["rejected"] good_ids, bad_ids = batch["chosen"], batch["rejected"]
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"] good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
@ -115,41 +115,6 @@ class DpoStrategy(BaseStrategy):
return dpo_loss return dpo_loss
class PpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, epsilon):
super().__init__(model)
ref_model = copy.deepcopy(self.model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.pad_token_id = pad_token_id
self.epsilon = epsilon
def ppo_clip_loss_masked(
self,
log_probs: Tensor,
old_log_probs: Tensor,
advantages: Tensor,
values: Tensor,
returns: Tensor,
mask: Tensor,
clip_eps: float=0.2,
):
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
value_loss = F.mse_loss(values.masked_select(mask),
returns.masked_select(mask))
entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
entropy_loss = -entropy
return policy_loss, value_loss, entropy_loss
class StrategyFactory: class StrategyFactory:
def load(model, train_type, device, **kwargs): def load(model, train_type, device, **kwargs):

View File

@ -5,10 +5,9 @@ import time
from pathlib import Path from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LRScheduler
from typing import List, Optional, Protocol, TYPE_CHECKING from typing import List, Optional, Protocol, TYPE_CHECKING
from khaosz.config import ScheduleConfig
from khaosz.trainer.metric_util import ( from khaosz.trainer.metric_util import (
grad_max, grad_max,
grad_min, grad_min,
@ -17,9 +16,9 @@ from khaosz.trainer.metric_util import (
grad_std, grad_std,
grad_nan_num grad_nan_num
) )
from khaosz.trainer.checkpoint import Checkpoint
if TYPE_CHECKING: if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.train_context import TrainContext from khaosz.trainer.train_context import TrainContext
@ -28,31 +27,31 @@ class TrainCallback(Protocol):
Callback interface for trainer. Callback interface for trainer.
""" """
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): def on_train_begin(self, context: 'TrainContext'):
""" Called at the beginning of training. """ """ Called at the beginning of training. """
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_train_end(self, context: 'TrainContext'):
""" Called at the end of training. """ """ Called at the end of training. """
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'): def on_epoch_begin(self, context: 'TrainContext'):
""" Called at the beginning of each epoch. """ """ Called at the beginning of each epoch. """
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_epoch_end(self, context: 'TrainContext'):
""" Called at the end of each epoch. """ """ Called at the end of each epoch. """
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): def on_step_begin(self, context: 'TrainContext'):
""" Called at the beginning of each step. """ """ Called at the beginning of each step. """
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_step_end(self, context: 'TrainContext'):
""" Called at the end of each step.""" """ Called at the end of each step."""
def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'): def on_batch_begin(self, context: 'TrainContext'):
""" Called at the beginning of each batch. """ """ Called at the beginning of each batch. """
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_batch_end(self, context: 'TrainContext'):
""" Called at the end of each batch. """ """ Called at the end of each batch. """
def on_error(self, trainer: 'Trainer', context: 'TrainContext'): def on_error(self, context: 'TrainContext'):
""" Called when an error occurs during training. """ """ Called when an error occurs during training. """
@ -63,29 +62,27 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float): def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): def on_step_begin(self, context: 'TrainContext'):
_ = context _ = context
clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm) clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
class SchedulerCallback(TrainCallback): class SchedulerCallback(TrainCallback):
""" """
Scheduler callback for trainer. Scheduler callback for trainer.
""" """
def __init__(self, schedule_config: ScheduleConfig): def __init__(self, scheduler: LRScheduler):
self.schedule_config = schedule_config self.scheduler: LRScheduler = scheduler
self.scheduler: Optional[LambdaLR] = None
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): def on_train_begin(self, context: 'TrainContext'):
for group in context.optimizer.param_groups:
for group in trainer.train_config.optimizer.param_groups:
if "initial_lr" not in group: if "initial_lr" not in group:
group["initial_lr"] = group["lr"] group["initial_lr"] = group["lr"]
self.scheduler = context.scheduler self.scheduler = context.scheduler
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_batch_end(self, context: 'TrainContext'):
_ = trainer, context _ = context
if self.scheduler: if self.scheduler:
self.scheduler.step() self.scheduler.step()
@ -94,54 +91,59 @@ class CheckpointCallback(TrainCallback):
""" """
Checkpoint callback for trainer. Checkpoint callback for trainer.
""" """
def __init__(self, checkpoint_interval: int): def __init__(self, interval: int, save_dir: str):
self.checkpoint_interval = checkpoint_interval self.interval = interval
self.save_dir = save_dir
self.checkpoint = None
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'): def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}") save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
context.checkpoint.optimizer_state = context.optimizer.state_dict() self.checkpoint = Checkpoint(
context.checkpoint.scheduler_state = context.scheduler.state_dict() context.optimizer.state_dict(),
context.checkpoint.epoch = context.epoch context.scheduler.state_dict(),
context.checkpoint.batch_iter = context.batch_iter context.epoch,
context.checkpoint.save(save_path) context.iteration
self.last_ckpt_iter = context.batch_iter )
self.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_batch_end(self, context: 'TrainContext'):
context.checkpoint.loss_list.append(context.loss) if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval: def on_train_end(self, context: 'TrainContext'):
self._save_checkpoint(trainer, context) if context.iteration != self.last_ckpt_iter:
self._save_checkpoint(context)
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_error(self, context: 'TrainContext'):
if context.batch_iter != self.last_ckpt_iter: self._save_checkpoint(context)
self._save_checkpoint(trainer, context)
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):
""" """
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
def __init__(self): def __init__(self, num_epoch: int):
self.num_epoch = num_epoch
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'): def on_epoch_begin(self, context: 'TrainContext'):
self.progress_bar = tqdm( self.progress_bar = tqdm(
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}", desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
dynamic_ncols=True dynamic_ncols=True
) )
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_batch_end(self, context: 'TrainContext'):
_ = trainer
self.progress_bar.set_postfix({ self.progress_bar.set_postfix({
"loss": f"{context.loss:.4f}", "loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}" "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
}) })
self.progress_bar.update(1) self.progress_bar.update(1)
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_epoch_end(self, context: 'TrainContext'):
_ = trainer, context _ = context
if self.progress_bar: if self.progress_bar:
self.progress_bar.close() self.progress_bar.close()
@ -177,13 +179,13 @@ class StepMonitorCallback(TrainCallback):
self.log_dir.mkdir(parents=True, exist_ok=True) self.log_dir.mkdir(parents=True, exist_ok=True)
def _handle_info(self, trainer: 'Trainer', context: 'TrainContext'): def _handle_info(self, context: 'TrainContext'):
""" Logs training information to console and file. """ """ Logs training information to console and file. """
log_data = { log_data = {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.batch_iter, "iter": context.iteration,
"metrics": self.metrics, "metrics": self.metrics,
} }
@ -193,34 +195,34 @@ class StepMonitorCallback(TrainCallback):
elif metric == 'lr': elif metric == 'lr':
log_data[metric] = context.optimizer.param_groups[-1]['lr'] log_data[metric] = context.optimizer.param_groups[-1]['lr']
elif metric == 'grad_norm': elif metric == 'grad_norm':
log_data[metric] = grad_norm(trainer.parameter.model) log_data[metric] = grad_norm(context.model)
elif metric == 'grad_std': elif metric == 'grad_std':
log_data[metric] = grad_std(trainer.parameter.model) log_data[metric] = grad_std(context.model)
elif metric == 'grad_max': elif metric == 'grad_max':
log_data[metric] = grad_max(trainer.parameter.model) log_data[metric] = grad_max(context.model)
elif metric == 'grad_min': elif metric == 'grad_min':
log_data[metric] = grad_min(trainer.parameter.model) log_data[metric] = grad_min(context.model)
elif metric == 'grad_mean': elif metric == 'grad_mean':
log_data[metric] = grad_mean(trainer.parameter.model) log_data[metric] = grad_mean(context.model)
elif metric == 'grad_nan_num': elif metric == 'grad_nan_num':
log_data[metric] = grad_nan_num(trainer.parameter.model) log_data[metric] = grad_nan_num(context.model)
else: else:
raise ValueError(f"Invalid metric: {metric}") raise ValueError(f"Invalid metric: {metric}")
return log_data return log_data
def _handle_log(self, trainer: 'Trainer', context: 'TrainContext'): def _handle_log(self, context: 'TrainContext'):
""" Logs training information to console and file. """ """ Logs training information to console and file. """
log_data = self._handle_info(trainer, context) log_data = self._handle_info(context)
try: try:
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json" log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.iteration}.json"
with open(log_file, 'a') as f: with open(log_file, 'a') as f:
json.dump(log_data, f, indent=4) json.dump(log_data, f, indent=4)
except Exception: except Exception:
raise raise
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_step_end(self, context: 'TrainContext'):
if self.step_num % self.log_interval == 0: if self.step_num % self.log_interval == 0:
self._handle_log(trainer, context) self._handle_log(context)
self.step_num += 1 self.step_num += 1

View File

@ -1,61 +1,60 @@
from dataclasses import dataclass, field, fields import torch.nn as nn
from typing import Optional, Self, TYPE_CHECKING
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from khaosz.config import Checkpoint
from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
from khaosz.parallel.utils import get_world_size, get_rank
if TYPE_CHECKING: from khaosz.data import ResumableDistributedSampler
from khaosz.trainer.trainer import Trainer from khaosz.trainer.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.config.train_config import TrainConfig
from khaosz.parallel.utils import get_current_device, get_world_size, get_rank
from dataclasses import dataclass, field
from typing import Optional, Self
@dataclass @dataclass
class TrainContext: class TrainContext:
model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None) dataloader: DataLoader = field(default=None)
optimizer: Optimizer = field(default=None) optimizer: Optimizer = field(default=None)
scheduler: BaseScheduler = field(default=None) scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
batch_iter: int = field(default=0) iteration: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
wolrd_size: int = field(default=1) wolrd_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
def asdict(self) -> dict:
return {field.name: getattr(self, field.name)
for field in fields(self)}
class TrainContextBuilder: class TrainContextBuilder:
def __init__(self, trainer: 'Trainer'): def __init__(self, config: TrainConfig):
self.trainer = trainer self.config = config
self._context: TrainContext = None self._context: TrainContext = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._context = TrainContext() self._context = TrainContext()
if checkpoint is None: if checkpoint is None:
checkpoint = Checkpoint( checkpoint = Checkpoint(
model=self.trainer.parameter.model, optimizer_state=self.config.optimizer.state_dict(),
tokenizer=self.trainer.parameter.tokenizer, scheduler_state=self.config.scheduler.state_dict(),
config=self.trainer.parameter.config,
) )
else: else:
# resume from the assigned checkpoint or assigned iteration # resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch) self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch) self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.checkpoint = checkpoint self._context.checkpoint = checkpoint
return self return self
def with_optimizer(self) -> Self: def with_optimizer(self) -> Self:
optimizer = self.config.optimizer
if self._context is None: if self._context is None:
raise RuntimeError("Must call with_checkpoint() before with_optimizer()") raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
optimizer = self.trainer.train_config.optimizer
if self._context.checkpoint and self._context.checkpoint.optimizer_state: if self._context.checkpoint and self._context.checkpoint.optimizer_state:
optimizer.load_state_dict(self._context.checkpoint.optimizer_state) optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
@ -67,13 +66,7 @@ class TrainContextBuilder:
return self return self
def with_scheduler(self) -> Self: def with_scheduler(self) -> Self:
if not hasattr(self._context, 'optimizer') or self._context.optimizer is None: scheduler = self.config.scheduler
raise RuntimeError("Must call with_optimizer() before with_scheduler()")
optimizer = self.trainer.train_config.optimizer
schedule_config = self.trainer.schedule_config
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
if self._context.checkpoint and self._context.checkpoint.scheduler_state: if self._context.checkpoint and self._context.checkpoint.scheduler_state:
scheduler.load_state_dict(self._context.checkpoint.scheduler_state) scheduler.load_state_dict(self._context.checkpoint.scheduler_state)
@ -85,29 +78,41 @@ class TrainContextBuilder:
return self return self
def with_dataloader(self) -> Self: def with_dataloader(self) -> Self:
# fix: change batch level batch_iter to sample level offset # fix: change batch level iteration to sample level offset
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size config = self.config
sampler_offset = self._context.iteration * config.batch_size
resumeable_sampler = ResumableDistributedSampler( resumeable_sampler = ResumableDistributedSampler(
data_source=self.trainer.train_config.dataset, data_source=config.dataset,
start_epoch=self._context.epoch, start_epoch=self._context.epoch,
start_iter=sampler_offset, start_iter=sampler_offset,
seed=self.trainer.train_config.random_seed seed=config.random_seed
) )
dataloader = DataLoader( dataloader = DataLoader(
self.trainer.train_config.dataset, config.dataset,
batch_size=self.trainer.train_config.batch_size, batch_size=config.batch_size,
sampler=resumeable_sampler, sampler=resumeable_sampler,
num_workers=self.trainer.train_config.num_workers, num_workers=config.num_workers,
pin_memory=self.trainer.train_config.pin_memory, pin_memory=config.pin_memory,
prefetch_factor=self.trainer.train_config.prefetch_factor prefetch_factor=config.prefetch_factor
) )
self._context.dataloader = dataloader self._context.dataloader = dataloader
return self return self
def build(self) -> TrainContext: def with_strategy(self) -> Self:
device = get_current_device()
self._context.strategy = StrategyFactory.load(
model=self.config.model,
train_type=self.config.strategy,
device=device,
**self.config.kwargs
)
return self
if self.trainer.train_config.nprocs > 1: def build(self) -> TrainContext:
self._context.model = self.config.model
if self.config.nprocs > 1:
self._context.wolrd_size = get_world_size() self._context.wolrd_size = get_world_size()
self._context.rank = get_rank() self._context.rank = get_rank()

View File

@ -1,11 +1,6 @@
import logging import logging
from typing import Optional, List from typing import Optional, List
from khaosz.config import ( from khaosz.config import TrainConfig
ModelParameter,
Checkpoint,
ScheduleConfig,
TrainConfig
)
from khaosz.trainer.train_callback import ( from khaosz.trainer.train_callback import (
TrainCallback, TrainCallback,
ProgressBarCallback, ProgressBarCallback,
@ -13,7 +8,7 @@ from khaosz.trainer.train_callback import (
GradientClippingCallback, GradientClippingCallback,
SchedulerCallback SchedulerCallback
) )
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -21,66 +16,65 @@ logger = logging.getLogger(__name__)
class Trainer: class Trainer:
def __init__( def __init__(
self, self,
parameter: ModelParameter,
train_config: TrainConfig, train_config: TrainConfig,
schedule_config: ScheduleConfig,
callbacks: Optional[List[TrainCallback]] = None callbacks: Optional[List[TrainCallback]] = None
): ):
self.parameter = parameter
self.train_config = train_config self.train_config = train_config
self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks() self.callbacks = callbacks or self._get_default_callbacks()
def _get_default_callbacks(self) -> List[TrainCallback]: def _get_default_callbacks(self) -> List[TrainCallback]:
train_config = self.train_config
return [ return [
ProgressBarCallback(), ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(self.train_config.checkpoint_interval), CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
GradientClippingCallback(self.train_config.max_grad_norm), GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(self.schedule_config), SchedulerCallback(train_config.scheduler),
] ]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self) return (TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint) .with_checkpoint(checkpoint)
.with_optimizer() .with_optimizer()
.with_scheduler() .with_scheduler()
.with_dataloader() .with_dataloader()
.with_strategy()
.build()) .build())
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks: for callback in self.callbacks:
method = getattr(callback, method_name, None) method = getattr(callback, method_name, None)
if method: if method:
method(self, context) method(context)
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint) context = self._build_context(checkpoint)
self._call_callbacks('on_train_begin', context) self._call_callbacks('on_train_begin', context)
try: try:
self.parameter.model.train() context.model.train()
# 1.epoch # 1.epoch
for epoch in range(context.epoch, self.train_config.n_epoch): for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch context.epoch = epoch
self._call_callbacks('on_epoch_begin', context) self._call_callbacks('on_epoch_begin', context)
for batch in context.dataloader: for batch in context.dataloader:
if context.batch_iter % self.train_config.accumulation_steps == 0: if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step # 2. step
self._call_callbacks('on_step_begin', context) self._call_callbacks('on_step_begin', context)
self.train_config.optimizer.step() context.optimizer.step()
self.train_config.optimizer.zero_grad() context.optimizer.zero_grad()
self._call_callbacks('on_step_end', context) self._call_callbacks('on_step_end', context)
# 3. batch # 3. batch
self._call_callbacks('on_batch_begin', context) self._call_callbacks('on_batch_begin', context)
loss = self.train_config.strategy(batch) loss = context.strategy(batch)
context.loss = loss.item() context.loss = loss.item()
context.batch_iter += 1 context.iteration += 1
# to make the loss normalized by accumulation steps # to make the loss normalized by accumulation steps
normalized_loss = loss / self.train_config.accumulation_steps stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs
normalized_loss.backward() stand_loss = loss / stand_batch
stand_loss.backward()
self._call_callbacks('on_batch_end', context) self._call_callbacks('on_batch_end', context)

View File

@ -5,10 +5,20 @@ from khaosz.trainer import *
def test_callback_integration(base_test_env, random_dataset): def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated""" """Test that all callbacks are properly integrated"""
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"],
strategy='seq',
dataset=random_dataset, dataset=random_dataset,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler,
checkpoint_dir=base_test_env["test_dir"], checkpoint_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=2, batch_size=2,
@ -18,36 +28,26 @@ def test_callback_integration(base_test_env, random_dataset):
random_seed=42 random_seed=42
) )
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
# Create custom callbacks to track calls # Create custom callbacks to track calls
callback_calls = [] callback_calls = []
class TrackingCallback(TrainCallback): class TrackingCallback(TrainCallback):
def on_train_begin(self, trainer, context): def on_train_begin(self, context):
callback_calls.append('on_train_begin') callback_calls.append('on_train_begin')
def on_batch_end(self, trainer, context): def on_batch_end(self, context):
callback_calls.append('on_batch_end') callback_calls.append('on_batch_end')
def on_epoch_end(self, trainer, context): def on_epoch_end(self, context):
callback_calls.append('on_epoch_end') callback_calls.append('on_epoch_end')
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],
base_test_env["transformer_config"]
)
trainer = Trainer( trainer = Trainer(
model_parameter,
train_config, train_config,
schedule_config, callbacks=[TrackingCallback()]
callbacks=[TrackingCallback(), ProgressBarCallback()]
) )
trainer.train() trainer.train()

View File

@ -7,35 +7,34 @@ from khaosz.trainer import *
def test_early_stopping_simulation(base_test_env, early_stopping_dataset): def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior""" """Simulate early stopping behavior"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
strategy="seq",
scheduler=scheduler,
model=base_test_env["model"],
dataset=early_stopping_dataset, dataset=early_stopping_dataset,
optimizer=optimizer, optimizer=optimizer,
checkpoint_dir=base_test_env["test_dir"], checkpoint_dir=base_test_env["test_dir"],
n_epoch=2, n_epoch=2,
batch_size=2, batch_size=2,
checkpoint_interval=2, checkpoint_interval=1,
accumulation_steps=2, accumulation_steps=2,
random_seed=np.random.randint(1e4), random_seed=np.random.randint(1e4),
) )
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) trainer = Trainer(train_config)
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],
base_test_env["transformer_config"]
)
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
trainer = Trainer(model_parameter, train_config, schedule_config)
# Should handle early stopping gracefully # Should handle early stopping gracefully
checkpoint = None checkpoint = None
try: try:
checkpoint = trainer.train() checkpoint = trainer.train()
assert len(checkpoint.loss_list) == 2 assert checkpoint.iteration == 2
except Exception: except Exception:
# Handle any exceptions # Handle any exceptions
pass pass
checkpoint = trainer.train(checkpoint) checkpoint = trainer.train(checkpoint)
assert len(checkpoint.loss_list) == 10 assert checkpoint.iteration == 10

View File

@ -51,13 +51,6 @@ def test_env(request: pytest.FixtureRequest):
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
# parameter loader
def test_parameter_loader(test_env):
loaded_param = ParameterLoader.load(test_env["test_dir"])
assert loaded_param.model is not None
assert loaded_param.tokenizer is not None
assert loaded_param.config == test_env["transformer_config"]
def test_model_parameter(test_env): def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save") save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"]) model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])

View File

@ -31,10 +31,18 @@ def test_gradient_accumulation(base_test_env, random_dataset):
accumulation_steps_list = [1, 2, 4] accumulation_steps_list = [1, 2, 4]
for accumulation_steps in accumulation_steps_list: for accumulation_steps in accumulation_steps_list:
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
dataset=random_dataset, strategy="seq",
model=base_test_env["model"],
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler,
dataset=random_dataset,
checkpoint_dir=base_test_env["test_dir"], checkpoint_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=2, batch_size=2,
@ -44,18 +52,7 @@ def test_gradient_accumulation(base_test_env, random_dataset):
random_seed=42 random_seed=42
) )
schedule_config = CosineScheduleConfig( trainer = Trainer(train_config)
warmup_steps=10,
total_steps=20
)
train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
model_parameter = ModelParameter(
base_test_env["model"],
base_test_env["tokenizer"],
base_test_env["transformer_config"]
)
trainer = Trainer(model_parameter, train_config, schedule_config)
trainer.train() trainer.train()
assert train_config.accumulation_steps == accumulation_steps assert train_config.accumulation_steps == accumulation_steps

View File

@ -35,7 +35,7 @@ def test_schedule_factory_random_configs():
config.validate() config.validate()
# Create scheduler using factory # Create scheduler using factory
scheduler = SchedulerFactory.load_scheduler(optimizer, config) scheduler = SchedulerFactory.load(optimizer, config)
# Verify scheduler type # Verify scheduler type
if isinstance(config, CosineScheduleConfig): if isinstance(config, CosineScheduleConfig):
@ -83,7 +83,7 @@ def test_schedule_factory_edge_cases():
for config in edge_cases: for config in edge_cases:
config.validate() config.validate()
scheduler = SchedulerFactory.load_scheduler(optimizer, config) scheduler = SchedulerFactory.load(optimizer, config)
assert scheduler is not None assert scheduler is not None
# Test multiple steps # Test multiple steps
@ -97,16 +97,17 @@ def test_schedule_factory_invalid_configs():
# Test invalid configurations that should raise errors # Test invalid configurations that should raise errors
invalid_configs = [ invalid_configs = [
# Negative warmup steps # Negative warmup steps
CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1), {"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
# Total steps less than warmup steps # Total steps less than warmup steps
CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1), {"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
# Invalid min_rate # Invalid min_rate
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1), {"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1), {"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
] ]
for config in invalid_configs: for kwargs in invalid_configs:
with pytest.raises(ValueError): with pytest.raises(ValueError):
config = CosineScheduleConfig(**kwargs)
config.validate() config.validate()
@ -117,7 +118,7 @@ def test_schedule_factory_state_persistence():
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1) config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
scheduler = SchedulerFactory.load_scheduler(optimizer, config) scheduler = SchedulerFactory.load(optimizer, config)
# Take a few steps # Take a few steps
for _ in range(5): for _ in range(5):
@ -127,7 +128,7 @@ def test_schedule_factory_state_persistence():
state_dict = scheduler.state_dict() state_dict = scheduler.state_dict()
# Create new scheduler and load state # Create new scheduler and load state
new_scheduler = SchedulerFactory.load_scheduler(optimizer, config) new_scheduler = SchedulerFactory.load(optimizer, config)
new_scheduler.load_state_dict(state_dict) new_scheduler.load_state_dict(state_dict)
# Verify states match # Verify states match

View File

@ -3,8 +3,8 @@ import argparse
import torch import torch
from torch.optim import AdamW from torch.optim import AdamW
from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, StrategyFactory from khaosz.trainer import Trainer, SchedulerFactory
from khaosz.data import DatasetLoader from khaosz.data import DatasetLoader
@ -36,7 +36,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--resume_from_checkpoint", action="store_true", help="Train from checkpoint or not.")
args = parser.parse_args() args = parser.parse_args()
@ -66,16 +65,12 @@ def train(
pin_memory: bool, pin_memory: bool,
window_size: int, window_size: int,
stride: int, stride: int,
resume_from_checkpoint: bool
): ):
assert train_type in ["seq", "sft", "dpo"] assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
parameter = ParameterLoader.load(param_path) parameter = ModelParameter()
checkpoint = None parameter.load(param_path)
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
checkpoint = parameter
if window_size is None: if window_size is None:
window_size = parameter.config.m_len window_size = parameter.config.m_len
@ -91,13 +86,6 @@ def train(
"pad_token_id": parameter.tokenizer.pad_id, "pad_token_id": parameter.tokenizer.pad_id,
} }
strategy = StrategyFactory.load(
model,
train_type,
device,
**kwargs
)
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
@ -111,16 +99,25 @@ def train(
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr} {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
] ]
optim = AdamW( optimizer = AdamW(
param_groups, param_groups,
betas=(adamw_beta1, adamw_beta2), betas=(adamw_beta1, adamw_beta2),
weight_decay=adamw_weight_decay weight_decay=adamw_weight_decay
) )
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // batch_size,
)
scheduler = SchedulerFactory.load(optimizer, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
strategy=strategy, model=model,
strategy=train_type,
dataset=dataset, dataset=dataset,
optimizer=optim, optimizer=optimizer,
scheduler=scheduler,
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_size=batch_size, batch_size=batch_size,
@ -134,17 +131,8 @@ def train(
pin_memory=pin_memory pin_memory=pin_memory
) )
schedule_config = CosineScheduleConfig( trainer = Trainer(train_config)
warmup_steps=warmup_steps, trainer.train()
total_steps=len(dataset) * n_epoch // batch_size,
)
trainer = Trainer(
parameter=parameter,
train_config=train_config,
schedule_config=schedule_config,
)
trainer.train(checkpoint)
if __name__ == "__main__": if __name__ == "__main__":