style: 使用 ruff 工具优化代码风格

This commit is contained in:
ViperEkura 2026-03-30 23:32:28 +08:00
parent 345fd2f091
commit 426af2d75f
52 changed files with 1838 additions and 1495 deletions

27
.github/workflows/lint.yml vendored Normal file
View File

@ -0,0 +1,27 @@
name: Lint

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# Least-privilege token: this job only needs to read repository contents,
# matching the permissions style used by the repo's other workflows.
permissions:
  contents: read

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML does not parse the version as the float 3.12.
          python-version: "3.12"

      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install .[dev]

      # Run both the linter (`ruff check`) and the formatter check
      # (`ruff format --check`) so the workflow's "Lint" name matches
      # what it actually enforces; format --check alone never reports
      # lint-rule violations.
      - name: Lint and check formatting with ruff
        run: |
          ruff check .
          ruff format --check .

View File

@ -1,17 +0,0 @@
name: Spell Check
on: [push, pull_request]
permissions:
contents: read
jobs:
spellcheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check spelling in specific files
uses: codespell-project/actions-codespell@v2
with:
check_filenames: true
only_warn: false
path: "**/*.{md, py}"

31
.github/workflows/tests.yml vendored Normal file
View File

@ -0,0 +1,31 @@
name: Tests

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# Least-privilege token: running the test suite only needs read access,
# consistent with the permissions block used elsewhere in this repo.
permissions:
  contents: read

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML does not parse the version as the float 3.12.
        python-version: ["3.12"]
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          # Quote the template expansion so an unexpected value can never
          # be re-typed by the YAML parser (empty -> null, 3.10 -> 3.1).
          python-version: "${{ matrix.python-version }}"

      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install .[dev]

      - name: Run tests with pytest
        run: |
          python -m pytest tests/ -v

5
.gitignore vendored
View File

@ -8,5 +8,8 @@
!*.py !*.py
!*.md !*.md
!*.png !*.png
!LICENSE !LICENSE
!pyproject.toml !pyproject.toml
!.github/workflows/lint.yml
!.github/workflows/tests.yml

View File

@ -54,8 +54,8 @@ python train.py \
--n_epoch=5 \ --n_epoch=5 \
--batch_size=8 \ --batch_size=8 \
--max_lr=2e-4 \ --max_lr=2e-4 \
--checkpoint_interval=10000 \ --ckpt_interval=10000 \
--checkpoint_dir=checkpoints --ckpt_dir=checkpoints
``` ```
**Parameter Explanation:** **Parameter Explanation:**
@ -67,8 +67,8 @@ python train.py \
- `--accumulation_steps`: Number of batches per training step - `--accumulation_steps`: Number of batches per training step
- `--warmup_steps`: Warmup steps - `--warmup_steps`: Warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay) - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- `--checkpoint_interval`: Checkpoint saving interval - `--ckpt_interval`: Checkpoint saving interval
- `--checkpoint_dir`: Checkpoint saving directory - `--ckpt_dir`: Checkpoint saving directory
- `--resume_dir`: Resume training from specified path - `--resume_dir`: Resume training from specified path
@ -191,8 +191,8 @@ python train.py \
--n_epoch=5 \ --n_epoch=5 \
--batch_size=8 \ --batch_size=8 \
--max_lr=2e-4 \ --max_lr=2e-4 \
--checkpoint_interval=10000 \ --ckpt_interval=10000 \
--checkpoint_dir=checkpoints --ckpt_dir=checkpoints
``` ```
**参数说明:** **参数说明:**
@ -204,8 +204,8 @@ python train.py \
- `--accumulation_steps`: 每个训练步骤的 batch 数量 - `--accumulation_steps`: 每个训练步骤的 batch 数量
- `--warmup_steps`: 预热步数warmup steps - `--warmup_steps`: 预热步数warmup steps
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减) - `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
- `--checkpoint_interval`: 检查点保存间隔 - `--ckpt_interval`: 检查点保存间隔
- `--checkpoint_dir`: 检查点保存目录 - `--ckpt_dir`: 检查点保存目录
- `--resume_dir`: 从指定路径恢复训练 - `--resume_dir`: 从指定路径恢复训练

View File

@ -2,13 +2,12 @@ import os
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
if __name__ == "__main__": if __name__ == "__main__":
snapshot_download( snapshot_download(
repo_id="ViperEk/KHAOSZ", repo_id="ViperEk/KHAOSZ",
local_dir=os.path.join(PROJECT_ROOT, "params"), local_dir=os.path.join(PROJECT_ROOT, "params"),
force_download=True force_download=True,
) )

View File

@ -5,18 +5,18 @@ from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import LoopGenerator, GenerationRequest from khaosz.inference.generator import LoopGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
def generate_text(): def generate_text():
with disable_random_init(): with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params") model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
query = input(">> ") query = input(">> ")
request = GenerationRequest( request = GenerationRequest(
query=query, query=query,
temperature=0.8, temperature=0.8,
@ -28,8 +28,9 @@ def generate_text():
) )
generator = LoopGenerator(param) generator = LoopGenerator(param)
response = generator.generate(request) response = generator.generate(request)
print(response) print(response)
if __name__ == "__main__": if __name__ == "__main__":
generate_text() generate_text()

View File

@ -4,18 +4,24 @@ from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import BatchGenerator, GenerationRequest from khaosz.inference.generator import BatchGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
def batch_generate(): def batch_generate():
with disable_random_init(): with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params") model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
generator = BatchGenerator(param) generator = BatchGenerator(param)
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"] inputs = [
"你好",
"请问什么是人工智能",
"今天天气如何",
"我感到焦虑, 请问我应该怎么办",
"请问什么是显卡",
]
request = GenerationRequest( request = GenerationRequest(
query=inputs, query=inputs,
temperature=0.8, temperature=0.8,
@ -26,9 +32,10 @@ def batch_generate():
system_prompt=None, system_prompt=None,
) )
responses = generator.generate(request) responses = generator.generate(request)
for q, r in zip(inputs, responses): for q, r in zip(inputs, responses):
print((q, r)) print((q, r))
if __name__ == "__main__": if __name__ == "__main__":
batch_generate() batch_generate()

View File

@ -5,16 +5,16 @@ from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import StreamGenerator, GenerationRequest from khaosz.inference.generator import StreamGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
def chat(): def chat():
with disable_random_init(): with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params") model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
generator = StreamGenerator(param) generator = StreamGenerator(param)
history = [] history = []
@ -22,7 +22,7 @@ def chat():
query = input(">> ") query = input(">> ")
if query == "!exit": if query == "!exit":
break break
request = GenerationRequest( request = GenerationRequest(
query=query, query=query,
temperature=0.8, temperature=0.8,
@ -32,7 +32,7 @@ def chat():
history=history, history=history,
system_prompt=None, system_prompt=None,
) )
response_size = 0 response_size = 0
full_response = "" full_response = ""
for response in generator.generate(request): for response in generator.generate(request):
@ -40,10 +40,10 @@ def chat():
print(response[response_size:], end="", flush=True) print(response[response_size:], end="", flush=True)
response_size = len(response) response_size = len(response)
full_response = response full_response = response
# After generation, update history # After generation, update history
history.append((query, full_response.strip())) history.append((query, full_response.strip()))
if __name__ == "__main__": if __name__ == "__main__":
chat() chat()

View File

@ -6,41 +6,30 @@ from khaosz.config import (
TrainConfig, TrainConfig,
) )
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
from khaosz.data import ( from khaosz.data import DatasetLoader, BpeTokenizer
DatasetLoader,
BpeTokenizer
)
from khaosz.inference.generator import ( from khaosz.inference.generator import (
GenerationRequest, GenerationRequest,
LoopGenerator, LoopGenerator,
StreamGenerator, StreamGenerator,
BatchGenerator, BatchGenerator,
EmbeddingEncoder, EmbeddingEncoder,
GeneratorFactory GeneratorFactory,
)
from khaosz.trainer import (
Trainer,
StrategyFactory,
SchedulerFactory
) )
from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
__all__ = [ __all__ = [
"Transformer", "Transformer",
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",
"DatasetLoader", "DatasetLoader",
"BpeTokenizer", "BpeTokenizer",
"GenerationRequest", "GenerationRequest",
"LoopGenerator", "LoopGenerator",
"StreamGenerator", "StreamGenerator",
"BatchGenerator", "BatchGenerator",
"EmbeddingEncoder", "EmbeddingEncoder",
"GeneratorFactory", "GeneratorFactory",
"Trainer", "Trainer",
"StrategyFactory", "StrategyFactory",
"SchedulerFactory" "SchedulerFactory",
] ]

View File

@ -1,10 +1,10 @@
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter from khaosz.config.param_config import BaseModelIO, ModelParameter
from khaosz.config.schedule_config import ( from khaosz.config.schedule_config import (
ScheduleConfig, ScheduleConfig,
CosineScheduleConfig, CosineScheduleConfig,
SGDRScheduleConfig, SGDRScheduleConfig,
ScheduleConfigFactory ScheduleConfigFactory,
) )
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
@ -13,14 +13,12 @@ __all__ = [
# Base I/O # Base I/O
"BaseModelIO", "BaseModelIO",
"ModelParameter", "ModelParameter",
# Model configuration # Model configuration
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",
# Schedule configuration # Schedule configuration
"ScheduleConfig", "ScheduleConfig",
"CosineScheduleConfig", "CosineScheduleConfig",
"SGDRScheduleConfig", "SGDRScheduleConfig",
"ScheduleConfigFactory", "ScheduleConfigFactory",
] ]

View File

@ -14,30 +14,29 @@ class ModelConfig:
norm_eps: Optional[float] = None norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None tie_weight: Optional[bool] = None
# RoPE # RoPE
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
# GQA # GQA
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self: def load(self, config_path: str) -> Self:
config = {} config = {}
with open(config_path, 'r') as f: with open(config_path, "r") as f:
config.update(json.load(f)) config.update(json.load(f))
for key, value in config.items(): for key, value in config.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
return self return self
def save(self, config_path: str): def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None} config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4) json.dump(config_dict, f, indent=4)

View File

@ -9,58 +9,57 @@ from khaosz.data.tokenizer import BpeTokenizer
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
@dataclass @dataclass
class BaseModelIO: class BaseModelIO:
"""Base class for model I/O operations.""" """Base class for model I/O operations."""
model: Optional[nn.Module] = field( model: Optional[nn.Module] = field(
default=None, default=None, metadata={"help": "Transformer model."}
metadata={"help": "Transformer model."}
) )
tokenizer: BpeTokenizer = field( tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer, default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
metadata={"help": "Tokenizer for the model."}
) )
config: ModelConfig = field( config: ModelConfig = field(
default_factory=ModelConfig, default_factory=ModelConfig,
metadata={"help": "Transformer model configuration."} metadata={"help": "Transformer model configuration."},
) )
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]: def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get standardized file paths for model components.""" """Get standardized file paths for model components."""
dir_path = Path(directory) dir_path = Path(directory)
return { return {
"model": dir_path / "model.safetensors", "model": dir_path / "model.safetensors",
"config": dir_path / "config.json", "config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json" "tokenizer": dir_path / "tokenizer.json",
} }
def save_components(self, save_dir: Union[str, Path]): def save_components(self, save_dir: Union[str, Path]):
"""Save core model components.""" """Save core model components."""
paths = self._get_file_paths(save_dir) paths = self._get_file_paths(save_dir)
paths["model"].parent.mkdir(parents=True, exist_ok=True) paths["model"].parent.mkdir(parents=True, exist_ok=True)
if self.model is not None: if self.model is not None:
st.save_file(self.model.state_dict(), str(paths["model"])) st.save_file(self.model.state_dict(), str(paths["model"]))
self.config.save(str(paths["config"])) self.config.save(str(paths["config"]))
self.tokenizer.save(str(paths["tokenizer"])) self.tokenizer.save(str(paths["tokenizer"]))
def load_components(self, load_dir: Union[str, Path]) -> Self: def load_components(self, load_dir: Union[str, Path]) -> Self:
"""Load core model components.""" """Load core model components."""
paths = self._get_file_paths(load_dir) paths = self._get_file_paths(load_dir)
self.config.load(str(paths["config"])) self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"])) self.tokenizer.load(str(paths["tokenizer"]))
if self.model is None: if self.model is None:
self.model = Transformer(self.config) self.model = Transformer(self.config)
if paths["model"].exists(): if paths["model"].exists():
state_dict = st.load_file(str(paths["model"])) state_dict = st.load_file(str(paths["model"]))
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
return self return self
def to(self, *args, **kwargs) -> "BaseModelIO": def to(self, *args, **kwargs) -> "BaseModelIO":
"""Move model to device.""" """Move model to device."""
if self.model is not None: if self.model is not None:
@ -71,13 +70,12 @@ class BaseModelIO:
@dataclass @dataclass
class ModelParameter(BaseModelIO): class ModelParameter(BaseModelIO):
"""Container for model parameters with serialization capabilities.""" """Container for model parameters with serialization capabilities."""
@classmethod @classmethod
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]): def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
instance.save_components(save_dir) instance.save_components(save_dir)
@classmethod @classmethod
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter": def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
instance = cls() instance = cls()
return instance.load_components(load_dir) return instance.load_components(load_dir)

View File

@ -6,35 +6,35 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class ScheduleConfig(ABC): class ScheduleConfig(ABC):
"""Base configuration class for learning rate schedulers. """Base configuration class for learning rate schedulers.
Provides common validation and interface for all schedule types. Provides common validation and interface for all schedule types.
""" """
schedule_type: str = field( schedule_type: str = field(
default="cosine", default="cosine",
metadata={ metadata={
"help": "Type of learning rate schedule.", "help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"] "choices": ["cosine", "sgdr"],
} },
) )
warmup_steps: int = field( warmup_steps: int = field(
default=1000, default=1000, metadata={"help": "Number of warmup steps."}
metadata={"help": "Number of warmup steps."}
) )
min_rate: float = field( min_rate: float = field(
default=0.05, default=0.05, metadata={"help": "Minimum learning rate multiplier."}
metadata={"help": "Minimum learning rate multiplier."}
) )
@abstractmethod @abstractmethod
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
"""Get configuration kwargs for scheduler creation.""" """Get configuration kwargs for scheduler creation."""
raise NotImplementedError raise NotImplementedError
def validate(self) -> None: def validate(self) -> None:
"""Validate configuration parameters.""" """Validate configuration parameters."""
if self.warmup_steps < 0: if self.warmup_steps < 0:
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}") raise ValueError(
f"warmup_steps must be non-negative, got {self.warmup_steps}"
)
if not 0 <= self.min_rate <= 1: if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}") raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@ -42,44 +42,43 @@ class ScheduleConfig(ABC):
@dataclass @dataclass
class CosineScheduleConfig(ScheduleConfig): class CosineScheduleConfig(ScheduleConfig):
"""Cosine annealing learning rate schedule configuration.""" """Cosine annealing learning rate schedule configuration."""
total_steps: int = field( total_steps: int = field(
default=None, default=None, metadata={"help": "Total training steps for cosine schedule."}
metadata={"help": "Total training steps for cosine schedule."}
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.schedule_type = "cosine" self.schedule_type = "cosine"
self.validate() self.validate()
def get_kwargs(self) -> Dict[str, Any]: def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None: if self.total_steps is None:
raise ValueError("total_steps must be specified for cosine schedule") raise ValueError("total_steps must be specified for cosine schedule")
return { return {
"schedule_type": self.schedule_type, "schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps, "warmup_steps": self.warmup_steps,
"lr_decay_steps": self.total_steps - self.warmup_steps, "lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate "min_rate": self.min_rate,
} }
def validate(self) -> None: def validate(self) -> None:
super().validate() super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps: if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})") raise ValueError(
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
)
@dataclass @dataclass
class SGDRScheduleConfig(ScheduleConfig): class SGDRScheduleConfig(ScheduleConfig):
"""Stochastic Gradient Descent with Warm Restarts schedule configuration.""" """Stochastic Gradient Descent with Warm Restarts schedule configuration."""
cycle_length: int = field( cycle_length: int = field(
default=1000, default=1000, metadata={"help": "Length of the first cycle in steps."}
metadata={"help": "Length of the first cycle in steps."}
) )
t_mult: int = field( t_mult: int = field(
default=2, default=2, metadata={"help": "Multiplier for cycle length growth."}
metadata={"help": "Multiplier for cycle length growth."}
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -92,9 +91,9 @@ class SGDRScheduleConfig(ScheduleConfig):
"warmup_steps": self.warmup_steps, "warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length, "cycle_length": self.cycle_length,
"min_rate": self.min_rate, "min_rate": self.min_rate,
"t_mult": self.t_mult "t_mult": self.t_mult,
} }
def validate(self) -> None: def validate(self) -> None:
super().validate() super().validate()
if self.cycle_length <= 0: if self.cycle_length <= 0:
@ -105,33 +104,33 @@ class SGDRScheduleConfig(ScheduleConfig):
class ScheduleConfigFactory: class ScheduleConfigFactory:
"""Factory class for creating ScheduleConfig instances. """Factory class for creating ScheduleConfig instances.
Supports both direct instantiation and factory creation methods. Supports both direct instantiation and factory creation methods.
Example usage: Example usage:
# Direct creation # Direct creation
config = CosineScheduleConfig(total_steps=10000) config = CosineScheduleConfig(total_steps=10000)
# Factory method # Factory method
config = ScheduleConfigFactory.create("cosine", total_steps=10000) config = ScheduleConfigFactory.create("cosine", total_steps=10000)
""" """
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = { CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
"cosine": CosineScheduleConfig, "cosine": CosineScheduleConfig,
"sgdr": SGDRScheduleConfig, "sgdr": SGDRScheduleConfig,
} }
@classmethod @classmethod
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig: def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
"""Create a schedule config instance. """Create a schedule config instance.
Args: Args:
schedule_type: Type of schedule ("cosine", "sgdr") schedule_type: Type of schedule ("cosine", "sgdr")
**kwargs: Arguments passed to the config constructor **kwargs: Arguments passed to the config constructor
Returns: Returns:
ScheduleConfig instance ScheduleConfig instance
Raises: Raises:
ValueError: If schedule_type is not supported ValueError: If schedule_type is not supported
""" """
@ -140,11 +139,11 @@ class ScheduleConfigFactory:
f"Unknown schedule type: '{schedule_type}'. " f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}" f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
) )
config_cls = cls.CONFIG_MAP[schedule_type] config_cls = cls.CONFIG_MAP[schedule_type]
return config_cls(**kwargs) return config_cls(**kwargs)
@classmethod @classmethod
def available_types(cls) -> list: def available_types(cls) -> list:
"""Return list of available schedule type names.""" """Return list of available schedule type names."""
return list(cls.CONFIG_MAP.keys()) return list(cls.CONFIG_MAP.keys())

View File

@ -10,127 +10,92 @@ from typing import Callable, List, Optional
@dataclass @dataclass
class TrainConfig: class TrainConfig:
# basic setting # basic setting
model: nn.Module = field( model: nn.Module = field(default=None, metadata={"help": "Model for training."})
default=None, strategy: str = field(default=None, metadata={"help": "Training strategy."})
metadata={"help": "Model for training."} dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
)
strategy: str = field(
default=None,
metadata={"help": "Training strategy."}
)
dataset: Dataset = field(
default=None,
metadata={"help": "Dataset for training."}
)
optimizer_fn: Callable[[nn.Module], Optimizer] = field( optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, default=None, metadata={"help": "Optimizer factory for training."}
metadata={"help": "Optimizer factory for training."}
) )
scheduler_fn: Callable[[Optimizer], LRScheduler] = field( scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, default=None, metadata={"help": "Scheduler factory for training."}
metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(
default=1,
metadata={"help": "Number of epochs for training."}
)
batch_size: int = field(
default=4,
metadata={"help": "Batch size for training."}
) )
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
accumulation_steps: int = field( accumulation_steps: int = field(
default=1, default=1, metadata={"help": "Number of iterations between steps."}
metadata={"help": "Number of iterations between steps."}
) )
max_grad_norm: float = field( max_grad_norm: float = field(
default=1.0, default=1.0, metadata={"help": "Maximum gradient norm."}
metadata={"help": "Maximum gradient norm."}
) )
# checkpoint setting # checkpoint setting
start_epoch: int = field( start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
default=0,
metadata={"help": "Start epoch for training."}
)
start_batch: int = field( start_batch: int = field(
default=0, default=0, metadata={"help": "Start batch iteration for training."}
metadata={"help": "Start batch iteration for training."}
) )
checkpoint_dir: str = field( ckpt_dir: str = field(
default="./checkpoint", default="./checkpoint", metadata={"help": "Checkpoint directory."}
metadata={"help": "Checkpoint directory."}
) )
checkpoint_interval: int = field( ckpt_interval: int = field(
default=5000, default=5000, metadata={"help": "Number of iterations between checkpoints."}
metadata={"help": "Number of iterations between checkpoints."}
) )
# dataloader setting # dataloader setting
random_seed: int = field( random_seed: int = field(default=3407, metadata={"help": "Random seed."})
default=3407,
metadata={"help": "Random seed."}
)
num_workers: int = field( num_workers: int = field(
default=0, default=0, metadata={"help": "Number of workers for dataloader."}
metadata={"help": "Number of workers for dataloader."}
) )
prefetch_factor: Optional[int] = field( prefetch_factor: Optional[int] = field(
default=None, default=None, metadata={"help": "Prefetch factor for dataloader."}
metadata={"help": "Prefetch factor for dataloader."}
) )
pin_memory: bool = field( pin_memory: bool = field(
default=False, default=False, metadata={"help": "Pin memory for dataloader."}
metadata={"help": "Pin memory for dataloader."}
) )
# distributed training # distributed training
nprocs: int = field( nprocs: int = field(
default=1, default=1, metadata={"help": "Number of processes for distributed training."}
metadata={"help": "Number of processes for distributed training."}
) )
backend: str = field( backend: str = field(
default="nccl", default="nccl", metadata={"help": "Distributed training backend."}
metadata={"help": "Distributed training backend."}
) )
master_addr: str = field( master_addr: str = field(
default="localhost", default="localhost",
metadata={"help": "Master address for distributed training."} metadata={"help": "Master address for distributed training."},
) )
master_port: str = field( master_port: str = field(
default="29500", default="29500", metadata={"help": "Master port for distributed training."}
metadata={"help": "Master port for distributed training."}
) )
parallel_wrapper: Optional[Callable] = field( parallel_wrapper: Optional[Callable] = field(
default=None, default=None, metadata={"help": "Parallel function for training."}
metadata={"help": "Parallel function for training."}
) )
state_dict_fn: Optional[Callable] = field( state_dict_fn: Optional[Callable] = field(
default=None, default=None, metadata={"help": "Parallel function for state dict saving."}
metadata={"help": "Parallel function for state dict saving."}
) )
# others # others
device_ids: Optional[List[int]] = field( device_ids: Optional[List[int]] = field(
default=None, default=None, metadata={"help": "Device ids for distributed training."}
metadata={"help": "Device ids for distributed training."}
) )
device_type: str = field( device_type: str = field(
default="cuda", default="cuda", metadata={"help": "Device type for distributed training."}
metadata={"help": "Device type for distributed training."}
) )
extra_kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, default_factory=dict, metadata={"help": "Other arguments."}
metadata={"help": "Other arguments."}
) )
def __post_init__(self): def __post_init__(self):
self.validate() self.validate()
def validate(self): def validate(self):
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"] required_fields = [
"model",
"strategy",
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields: for field_name in required_fields:
if getattr(self, field_name) is None: if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.") raise ValueError(f"{field_name} is required.")

View File

@ -1,12 +1,12 @@
from khaosz.data.dataset import ( from khaosz.data.dataset import (
BaseDataset, BaseDataset,
SEQDataset, SEQDataset,
DPODataset, DPODataset,
SFTDataset, SFTDataset,
GRPODataset, GRPODataset,
MultiSegmentFetcher, MultiSegmentFetcher,
DatasetLoader, DatasetLoader,
DatasetFactory DatasetFactory,
) )
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
@ -15,21 +15,17 @@ from khaosz.data.sampler import ResumableDistributedSampler
__all__ = [ __all__ = [
# Base classes # Base classes
"BaseDataset", "BaseDataset",
# Dataset implementations # Dataset implementations
"SEQDataset", "SEQDataset",
"SFTDataset", "SFTDataset",
"DPODataset", "DPODataset",
"GRPODataset", "GRPODataset",
# Fetchers # Fetchers
"MultiSegmentFetcher", "MultiSegmentFetcher",
# Factory (DatasetLoader is alias for backward compatibility) # Factory (DatasetLoader is alias for backward compatibility)
"DatasetLoader", "DatasetLoader",
"DatasetFactory", "DatasetFactory",
# Tokenizer and sampler # Tokenizer and sampler
"BpeTokenizer", "BpeTokenizer",
"ResumableDistributedSampler" "ResumableDistributedSampler",
] ]

View File

@ -12,40 +12,42 @@ from typing import Callable, List, Dict, Literal, Optional, Union
class BaseSegmentFetcher: class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments. """Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments. multiple discontinuous segments.
""" """
def __init__(self, segments: List[Tensor]): def __init__(self, segments: List[Tensor]):
self.segments = segments self.segments = segments
self.cum_lengths = [] self.cum_lengths = []
total = 0 total = 0
for seg in segments: for seg in segments:
total += torch.numel(seg) total += torch.numel(seg)
self.cum_lengths.append(total) self.cum_lengths.append(total)
self.total_length = total self.total_length = total
def __len__(self) -> int: def __len__(self) -> int:
return self.total_length return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx). """Fetch data in the range [begin_idx, end_idx).
Args: Args:
begin_idx: Starting index (inclusive) begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive) end_idx: Ending index (exclusive)
Returns: Returns:
Concatenated tensor of data in the specified range Concatenated tensor of data in the specified range
""" """
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length): if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds") raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx: if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long) return torch.tensor([], dtype=torch.long)
# Find segment boundaries for the range # Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
@ -64,43 +66,44 @@ class BaseSegmentFetcher:
class MultiSegmentFetcher: class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys. """Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask"). Each key corresponds to a different type of data (e.g., "sequence", "mask").
""" """
def __init__(self, muti_segments: Dict): def __init__(self, muti_segments: Dict):
self.muti_keys = list(muti_segments.keys()) self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = { self.muti_fetchers = {
key: BaseSegmentFetcher(segments) key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
for key, segments in muti_segments.items()
} }
def __len__(self) -> int: def __len__(self) -> int:
"""Returns the minimum length across all fetchers.""" """Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.muti_fetchers.values()] len_list = [len(seg) for seg in self.muti_fetchers.values()]
return min(len_list) return min(len_list)
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict: def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys. """Fetch data for specific keys.
Args: Args:
begin_idx: Starting index begin_idx: Starting index
end_idx: Ending index end_idx: Ending index
keys: Single key or list of keys to fetch keys: Single key or list of keys to fetch
Returns: Returns:
Dictionary of tensors if multiple keys, single tensor if one key Dictionary of tensors if multiple keys, single tensor if one key
""" """
fetch_dict = {} fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys keys = [keys] if isinstance(keys, str) else keys
for key in keys: for key in keys:
fetcher = self.muti_fetchers[key] fetcher = self.muti_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys.""" """Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.muti_keys) return self.key_fetch(begin_idx, end_idx, self.muti_keys)
@ -108,10 +111,10 @@ class MultiSegmentFetcher:
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types. """Abstract base class for all dataset types.
Implements common functionality for window-based data fetching. Implements common functionality for window-based data fetching.
""" """
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__() super().__init__()
self.segments = {} self.segments = {}
@ -122,38 +125,38 @@ class BaseDataset(Dataset, ABC):
def load(self, load_path: str): def load(self, load_path: str):
"""Load dataset from HDF5 file. """Load dataset from HDF5 file.
Args: Args:
load_path: Path to the HDF5 data file load_path: Path to the HDF5 data file
""" """
self.segments = load_h5(load_path) self.segments = load_h5(load_path)
self.fetcher = MultiSegmentFetcher(self.segments) self.fetcher = MultiSegmentFetcher(self.segments)
self.total_samples = len(self.fetcher) self.total_samples = len(self.fetcher)
def get_index(self, index: int) -> tuple: def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample. """Calculate begin and end indices for a sample.
Args: Args:
index: Sample index index: Sample index
Returns: Returns:
Tuple of (begin_idx, end_idx) Tuple of (begin_idx, end_idx)
""" """
assert self.total_samples > self.window_size assert self.total_samples > self.window_size
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size) begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, self.total_samples - 1) end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
return begin_idx, end_idx return begin_idx, end_idx
@abstractmethod @abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Get a single sample by index. """Get a single sample by index.
Must be implemented by subclasses. Must be implemented by subclasses.
""" """
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int: def __len__(self) -> int:
assert self.total_samples is not None assert self.total_samples is not None
if self.total_samples <= self.window_size: if self.total_samples <= self.window_size:
@ -163,48 +166,50 @@ class BaseDataset(Dataset, ABC):
class DatasetFactory: class DatasetFactory:
"""Factory class for creating dataset instances. """Factory class for creating dataset instances.
Supports decorator-based registration for extensible dataset types. Supports decorator-based registration for extensible dataset types.
All default dataset types (seq, sft, dpo, grpo) are registered automatically All default dataset types (seq, sft, dpo, grpo) are registered automatically
when their classes are defined with the decorator. when their classes are defined with the decorator.
Example usage: Example usage:
@DatasetFactory.register("custom") @DatasetFactory.register("custom")
class CustomDataset(BaseDataset): class CustomDataset(BaseDataset):
... ...
dataset = DatasetFactory.create("custom", window_size, stride) dataset = DatasetFactory.create("custom", window_size, stride)
""" """
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"}) SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
DATASET_MAP: Dict[str, type] = {} DATASET_MAP: Dict[str, type] = {}
@classmethod @classmethod
def register(cls, name: str): def register(cls, name: str):
"""Decorator to register a new dataset class. """Decorator to register a new dataset class.
Args: Args:
name: Registration name for the dataset type name: Registration name for the dataset type
Returns: Returns:
Decorator function that registers the dataset class Decorator function that registers the dataset class
""" """
def decorator(dataset_cls: type) -> type: def decorator(dataset_cls: type) -> type:
if not issubclass(dataset_cls, BaseDataset): if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
cls.DATASET_MAP[name] = dataset_cls cls.DATASET_MAP[name] = dataset_cls
return dataset_cls return dataset_cls
return decorator return decorator
@classmethod @classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset: def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
"""Create a dataset instance. """Create a dataset instance.
Args: Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo") train_type: Type of training ("seq", "sft", "dpo", "grpo")
window_size: Window size for data sampling window_size: Window size for data sampling
stride: Stride between consecutive samples stride: Stride between consecutive samples
Returns: Returns:
Dataset instance Dataset instance
""" """
@ -213,36 +218,42 @@ class DatasetFactory:
f"Unknown dataset type: '{train_type}'. " f"Unknown dataset type: '{train_type}'. "
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}" f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
) )
if train_type not in cls.DATASET_MAP: if train_type not in cls.DATASET_MAP:
raise NotImplementedError( raise NotImplementedError(
f"Dataset type '{train_type}' is supported but not yet implemented." f"Dataset type '{train_type}' is supported but not yet implemented."
) )
dataset_cls = cls.DATASET_MAP[train_type] dataset_cls = cls.DATASET_MAP[train_type]
return dataset_cls(window_size, stride) return dataset_cls(window_size, stride)
@classmethod @classmethod
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset: def load(
cls,
train_type: str,
load_path: str,
window_size: int,
stride: Optional[int] = None,
) -> BaseDataset:
"""Create and load a dataset in one step. """Create and load a dataset in one step.
Args: Args:
train_type: Type of training dataset train_type: Type of training dataset
load_path: Path to the data file load_path: Path to the data file
window_size: Window size for data sampling window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size) stride: Stride between consecutive samples (default: same as window_size)
Returns: Returns:
Loaded dataset instance Loaded dataset instance
""" """
if stride is None: if stride is None:
stride = window_size stride = window_size
dataset = cls.create(train_type, window_size, stride) dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path) dataset.load(load_path)
return dataset return dataset
@classmethod @classmethod
def available_types(cls) -> list: def available_types(cls) -> list:
"""Return list of registered dataset type names.""" """Return list of registered dataset type names."""
@ -256,46 +267,50 @@ class DatasetFactory:
@DatasetFactory.register("seq") @DatasetFactory.register("seq")
class SEQDataset(BaseDataset): class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training.""" """Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
return {"input_ids": x, "target_ids": y} return {"input_ids": x, "target_ids": y}
@DatasetFactory.register("sft") @DatasetFactory.register("sft")
class SFTDataset(BaseDataset): class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking.""" """Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool) dtype=torch.long
)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask} return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
@DatasetFactory.register("dpo") @DatasetFactory.register("dpo")
class DPODataset(BaseDataset): class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training.""" """Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
@ -304,25 +319,34 @@ class DPODataset(BaseDataset):
def __getitem__(self, index: int): def __getitem__(self, index: int):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long) chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long) rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool) chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool) dtype=torch.bool
)
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
dtype=torch.bool
)
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask} return {
"chosen": chosen,
"rejected": rejected,
"chosen_mask": chosen_mask,
"rejected_mask": rejected_mask,
}
@DatasetFactory.register("grpo") @DatasetFactory.register("grpo")
class GRPODataset(BaseDataset): class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training.""" """Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -330,8 +354,13 @@ class GRPODataset(BaseDataset):
responses = self._fetch_data(begin_idx, end_idx, "responses") responses = self._fetch_data(begin_idx, end_idx, "responses")
masks = self._fetch_data(begin_idx, end_idx, "masks") masks = self._fetch_data(begin_idx, end_idx, "masks")
rewards = self._fetch_data(begin_idx, end_idx, "rewards") rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards} return {
"prompts": prompts,
"responses": responses,
"masks": masks,
"rewards": rewards,
}
# Backward compatibility alias # Backward compatibility alias

View File

@ -7,45 +7,45 @@ from typing import Optional
class ResumableDistributedSampler(Sampler[int]): class ResumableDistributedSampler(Sampler[int]):
def __init__( def __init__(
self, self,
data_source: Dataset, data_source: Dataset,
start_epoch: int=0, start_epoch: int = 0,
start_iter: int=0, start_iter: int = 0,
seed: int=42, seed: int = 42,
drop_last: bool=False, drop_last: bool = False,
shuffle: bool=True, shuffle: bool = True,
process_group: Optional[dist.ProcessGroup]=None, process_group: Optional[dist.ProcessGroup] = None,
): ):
self.epoch = start_epoch self.epoch = start_epoch
self.iter = start_iter self.iter = start_iter
self.seed = seed self.seed = seed
self.num_samples = len(data_source) self.num_samples = len(data_source)
if process_group is not None: if process_group is not None:
# input process group # input process group
self.rank = dist.get_rank(process_group) self.rank = dist.get_rank(process_group)
self.num_replicas = dist.get_world_size(process_group) self.num_replicas = dist.get_world_size(process_group)
elif dist.is_available() and dist.is_initialized(): elif dist.is_available() and dist.is_initialized():
# use default process group # use default process group
process_group = dist.group.WORLD process_group = dist.group.WORLD
self.rank = dist.get_rank() self.rank = dist.get_rank()
self.num_replicas = dist.get_world_size() self.num_replicas = dist.get_world_size()
else: else:
# single process # single process
self.rank = 0 self.rank = 0
self.num_replicas = 1 self.num_replicas = 1
self.drop_last = drop_last self.drop_last = drop_last
self.shuffle = shuffle self.shuffle = shuffle
offset = 0 if drop_last else self.num_replicas - 1 offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas
self._indices = None self._indices = None
def _get_indices(self): def _get_indices(self):
if self.shuffle: if self.shuffle:
generator = torch.Generator() generator = torch.Generator()
@ -53,26 +53,26 @@ class ResumableDistributedSampler(Sampler[int]):
indices = torch.randperm(self.num_samples, generator=generator).tolist() indices = torch.randperm(self.num_samples, generator=generator).tolist()
else: else:
indices = torch.arange(self.num_samples).tolist() indices = torch.arange(self.num_samples).tolist()
if not self.drop_last and self.num_samples < self.total_size: if not self.drop_last and self.num_samples < self.total_size:
padding_size = self.total_size - len(indices) padding_size = self.total_size - len(indices)
indices += indices[:padding_size] indices += indices[:padding_size]
local_indices = indices[self.rank:self.total_size:self.num_replicas] local_indices = indices[self.rank : self.total_size : self.num_replicas]
self.iter = self.iter % self.num_samples_per_replica self.iter = self.iter % self.num_samples_per_replica
self._indices = local_indices[self.iter:] self._indices = local_indices[self.iter :]
def __iter__(self): def __iter__(self):
if self._indices is None: if self._indices is None:
self._get_indices() self._get_indices()
for i in self._indices: for i in self._indices:
self.iter += 1 self.iter += 1
yield i yield i
self.epoch += 1 self.epoch += 1
self._indices = None self._indices = None
def __len__(self): def __len__(self):
return self.num_samples_per_replica return self.num_samples_per_replica

View File

@ -10,24 +10,26 @@ from torch import Tensor
from typing import Any, Dict, List from typing import Any, Dict, List
from khaosz.parallel.setup import get_rank from khaosz.parallel.setup import get_rank
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5") full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, 'w') as f: with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items(): for key, tensors in tensor_group.items():
grp = f.create_group(key) grp = f.create_group(key)
for idx, tensor in enumerate(tensors): for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy() arr = tensor.cpu().numpy()
grp.create_dataset(f'data_{idx}', data=arr) grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {} tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path) root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files: for h5_file in h5_files:
with h5py.File(h5_file, 'r') as f: with h5py.File(h5_file, "r") as f:
for key in f.keys(): for key in f.keys():
grp = f[key] grp = f[key]
dsets = [] dsets = []
@ -37,7 +39,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
if share_memory: if share_memory:
tensor = tensor.share_memory_() tensor = tensor.share_memory_()
dsets.append(tensor) dsets.append(tensor)
if tensor_group.get(key) is None: if tensor_group.get(key) is None:
tensor_group[key] = [] tensor_group[key] = []
tensor_group[key].extend(dsets) tensor_group[key].extend(dsets)
@ -60,7 +62,7 @@ class Checkpoint:
self, self,
save_dir: str, save_dir: str,
) -> None: ) -> None:
save_path = Path(save_dir) save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
@ -72,7 +74,7 @@ class Checkpoint:
} }
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2) json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / f"state_dict.safetensors") st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
@classmethod @classmethod
@ -83,7 +85,7 @@ class Checkpoint:
rank = get_rank() rank = get_rank()
save_path = Path(save_dir) save_path = Path(save_dir)
meta = {} meta = {}
if rank == 0: if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f: with open(Path(save_dir) / "meta.json", "r") as f:
@ -100,4 +102,4 @@ class Checkpoint:
state_dict=state_dict, state_dict=state_dict,
epoch=meta["epoch"], epoch=meta["epoch"],
iteration=meta["iteration"], iteration=meta["iteration"],
) )

View File

@ -9,34 +9,46 @@ class BpeTokenizer:
def __init__(self, path=None): def __init__(self, path=None):
self._control_tokens = ["<bos>", "<eos>", "<pad>"] self._control_tokens = ["<bos>", "<eos>", "<pad>"]
self._special_tokens = ["<|im_start|>", "<|im_end|>"] self._special_tokens = ["<|im_start|>", "<|im_end|>"]
model = BPE() model = BPE()
self._tokenizer = Tokenizer(model) self._tokenizer = Tokenizer(model)
self._tokenizer.normalizer = normalizers.Sequence([ self._tokenizer.normalizer = normalizers.Sequence(
normalizers.NFC(), [normalizers.NFC(), normalizers.Strip()]
normalizers.Strip() )
])
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ [
pre_tokenizers.UnicodeScripts(), pre_tokenizers.UnicodeScripts(),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True) pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
]) ]
)
self._tokenizer.decoder = decoders.ByteLevel() self._tokenizer.decoder = decoders.ByteLevel()
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
if path is not None: if path is not None:
self._tokenizer = Tokenizer.from_file(path) self._tokenizer = Tokenizer.from_file(path)
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple: def _prepare_trainer(
self,
vocab_size: int,
min_freq: int,
reserved_token_size: int,
max_token_length=18,
) -> tuple:
assert reserved_token_size > len(self._special_tokens) assert reserved_token_size > len(self._special_tokens)
reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))] reserved_tokens = [
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens)) f"<|reserve{i:02d}|>"
for i in range(reserved_token_size - len(self._special_tokens))
]
detail_vocab_size = vocab_size - (
len(reserved_tokens) + len(self._special_tokens)
)
alphabet = pre_tokenizers.ByteLevel.alphabet() alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self._control_tokens) min_size = len(alphabet) + len(self._control_tokens)
assert detail_vocab_size > min_size assert detail_vocab_size > min_size
trainer = BpeTrainer( trainer = BpeTrainer(
vocab_size=detail_vocab_size, vocab_size=detail_vocab_size,
min_frequency=min_freq, min_frequency=min_freq,
@ -46,61 +58,74 @@ class BpeTokenizer:
initial_alphabet=alphabet, initial_alphabet=alphabet,
show_progress=True, show_progress=True,
) )
return trainer, detail_vocab_size, reserved_tokens return trainer, detail_vocab_size, reserved_tokens
def train(self, files, vocab_size, min_freq, reserved_token_size=100): def train(self, files, vocab_size, min_freq, reserved_token_size=100):
trainer, _, reserved_tokens = self._prepare_trainer( trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size, vocab_size=vocab_size,
min_freq=min_freq, min_freq=min_freq,
reserved_token_size=reserved_token_size reserved_token_size=reserved_token_size,
) )
self._tokenizer.train(files=files, trainer=trainer) self._tokenizer.train(files=files, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens) self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100): def train_from_iterator(
self, iterator, vocab_size, min_freq, reserved_token_size=100
):
trainer, _, reserved_tokens = self._prepare_trainer( trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size, vocab_size=vocab_size,
min_freq=min_freq, min_freq=min_freq,
reserved_token_size=reserved_token_size reserved_token_size=reserved_token_size,
) )
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer) self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens) self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def save(self, path): def save(self, path):
self._tokenizer.save(path) self._tokenizer.save(path)
def load(self, path): def load(self, path):
self._tokenizer = Tokenizer.from_file(path) self._tokenizer = Tokenizer.from_file(path)
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List: def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
add_special_tokens: bool = False,
) -> List:
if isinstance(tokens, str): if isinstance(tokens, str):
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens) encoded: Encoding = self._tokenizer.encode(
tokens, add_special_tokens=add_special_tokens
)
return encoded.ids if out_ids else encoded.tokens return encoded.ids if out_ids else encoded.tokens
elif isinstance(tokens, list): elif isinstance(tokens, list):
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens) encoded_list: List[Encoding] = self._tokenizer.encode_batch(
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list] tokens, add_special_tokens=add_special_tokens
)
return [
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
]
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str: def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int: def __len__(self) -> int:
return self._tokenizer.get_vocab_size() return self._tokenizer.get_vocab_size()
@property @property
def stop_ids(self) -> List[int]: def stop_ids(self) -> List[int]:
stop_token = self._control_tokens + self._special_tokens stop_token = self._control_tokens + self._special_tokens
stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token] stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
return stop_ids return stop_ids
@property @property
def bos_id(self) -> int: def bos_id(self) -> int:
return self._tokenizer.token_to_id("<bos>") return self._tokenizer.token_to_id("<bos>")
@property @property
def eos_id(self) -> int: def eos_id(self) -> int:
return self._tokenizer.token_to_id("<eos>") return self._tokenizer.token_to_id("<eos>")
@property @property
def pad_id(self) -> int: def pad_id(self) -> int:
return self._tokenizer.token_to_id("<pad>") return self._tokenizer.token_to_id("<pad>")

View File

@ -11,7 +11,7 @@ from khaosz.inference.generator import (
StreamGenerator, StreamGenerator,
BatchGenerator, BatchGenerator,
EmbeddingEncoder, EmbeddingEncoder,
GeneratorFactory GeneratorFactory,
) )
__all__ = [ __all__ = [
@ -19,11 +19,10 @@ __all__ = [
"GeneratorCore", "GeneratorCore",
"EmbeddingEncoderCore", "EmbeddingEncoderCore",
"KVCacheManager", "KVCacheManager",
"GenerationRequest", "GenerationRequest",
"LoopGenerator", "LoopGenerator",
"StreamGenerator", "StreamGenerator",
"BatchGenerator", "BatchGenerator",
"EmbeddingEncoder", "EmbeddingEncoder",
"GeneratorFactory" "GeneratorFactory",
] ]

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, List, Tuple, Union, Optional, Self from typing import Any, Callable, List, Tuple, Union, Optional, Self
from khaosz.config import ModelParameter, ModelConfig from khaosz.config import ModelParameter, ModelConfig
@ -12,58 +12,61 @@ def apply_sampling_strategies(
temperature: float, temperature: float,
top_k: int, top_k: int,
top_p: float, top_p: float,
filter_value: float = -float("inf") filter_value: float = -float("inf"),
) -> Tensor: ) -> Tensor:
""" """
Apply sampling strategies to the logits tensor. Apply sampling strategies to the logits tensor.
Args: Args:
logits (Tensor): The logits tensor. logits (Tensor): The logits tensor.
temperature (float): The temperature parameter. temperature (float): The temperature parameter.
top_k (int): The top-k parameter. top_k (int): The top-k parameter.
top_p (float): The top-p parameter. top_p (float): The top-p parameter.
filter_value (float, optional): The filter value. Defaults to -float("inf"). filter_value (float, optional): The filter value. Defaults to -float("inf").
Returns: Returns:
Tensor: The sampled logits tensor. Tensor: The sampled logits tensor.
""" """
if temperature != 1.0: if temperature != 1.0:
logits = logits / temperature logits = logits / temperature
if top_k > 0: if top_k > 0:
top_k = min(top_k, logits.size(-1)) top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value logits[indices_to_remove] = filter_value
if top_p < 1.0: if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0 sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_( indices_to_remove.scatter_(
dim=1, dim=1, index=sorted_indices, src=sorted_indices_to_remove
index=sorted_indices,
src=sorted_indices_to_remove
) )
logits[indices_to_remove] = filter_value logits[indices_to_remove] = filter_value
return logits return logits
@contextmanager @contextmanager
def disable_random_init(): def disable_random_init():
init_functions = [ init_functions = [
'xavier_normal_', 'xavier_uniform_', "xavier_normal_",
'kaiming_normal_', 'kaiming_uniform_', "xavier_uniform_",
'zeros_', 'ones_', 'constant_', "kaiming_normal_",
'normal_', 'uniform_' "kaiming_uniform_",
"zeros_",
"ones_",
"constant_",
"normal_",
"uniform_",
] ]
original_funcs = {} original_funcs = {}
for name in init_functions: for name in init_functions:
@ -82,7 +85,7 @@ class GeneratorCore:
self.model = parameter.model self.model = parameter.model
self.tokenizer = parameter.tokenizer self.tokenizer = parameter.tokenizer
self.config = parameter.config self.config = parameter.config
def generate_iterator( def generate_iterator(
self, self,
input_ids: Tensor, input_ids: Tensor,
@ -91,18 +94,18 @@ class GeneratorCore:
top_p: float, top_p: float,
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0 start_pos: int = 0,
)-> Tuple[Tensor, int]: ) -> Tuple[Tensor, int]:
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos) outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1) cache_increase = input_ids.size(-1)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p) logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1) probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1) next_token_id = torch.multinomial(probs, num_samples=1)
return next_token_id, cache_increase return next_token_id, cache_increase
def generate_loop( def generate_loop(
@ -115,14 +118,21 @@ class GeneratorCore:
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0, start_pos: int = 0,
callback: Optional[Callable[..., Any]] = None callback: Optional[Callable[..., Any]] = None,
) -> List[int]: ) -> List[int]:
cur_cache_pos = start_pos cur_cache_pos = start_pos
for _ in range(len(ids), self.config.max_len): for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos) input_ids,
temperature,
top_k,
top_p,
attn_mask,
kv_caches,
cur_cache_pos,
)
input_ids = next_token_id input_ids = next_token_id
ids.append(next_token_id.item()) ids.append(next_token_id.item())
cur_cache_pos += cache_increase cur_cache_pos += cache_increase
@ -132,9 +142,9 @@ class GeneratorCore:
if next_token_id.item() in self.tokenizer.stop_ids: if next_token_id.item() in self.tokenizer.stop_ids:
break break
return ids return ids
def to(self, *args, **kargs) -> Self: def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs) self.model.to(*args, **kargs)
return self return self
@ -145,32 +155,35 @@ class EmbeddingEncoderCore:
self.model = parameter.model self.model = parameter.model
self.tokenizer = parameter.tokenizer self.tokenizer = parameter.tokenizer
self.config = parameter.config self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list) with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence) ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids] batch_ids = ids if with_batch else [ids]
max_model_len = self.config.max_len max_model_len = self.config.max_len
all_fragments = [] all_fragments = []
fragment_origin_idx = [] fragment_origin_idx = []
for i, seq in enumerate(batch_ids): for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len: if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)] fragments = [
seq[j : j + max_model_len]
for j in range(0, len(seq), max_model_len)
]
all_fragments.extend(fragments) all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments)) fragment_origin_idx.extend([i] * len(fragments))
else: else:
all_fragments.append(seq) all_fragments.append(seq)
fragment_origin_idx.append(i) fragment_origin_idx.append(i)
#if empty fragments # if empty fragments
if not all_fragments or not ids: if not all_fragments or not ids:
return [] if with_batch else torch.tensor([]) return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len) max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = [] padded_ids = []
masks = [] masks = []
for seq in all_fragments: for seq in all_fragments:
@ -179,24 +192,30 @@ class EmbeddingEncoderCore:
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq] mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq) padded_ids.append(padded_seq)
masks.append(mask) masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long) input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool) seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"] outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size] # [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1)) fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = [] sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)): for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i] indices = [
idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
]
if indices: if indices:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size] sum_frags = torch.sum(
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1] fragment_embs[indices, :, :], dim=1
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size] ) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(
1
) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten()) sentence_embs.append(emb.flatten())
if with_batch: if with_batch:
return [emb.flatten() for emb in sentence_embs] return [emb.flatten() for emb in sentence_embs]
else: else:
@ -209,11 +228,11 @@ class EmbeddingEncoderCore:
class KVCacheManager: class KVCacheManager:
def __init__( def __init__(
self, self,
config: ModelConfig, config: ModelConfig,
batch_size: int, batch_size: int,
device: torch.device = "cuda", device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16 dtype: torch.dtype = torch.bfloat16,
): ):
self.batch_size = batch_size self.batch_size = batch_size
self.device = device self.device = device
@ -221,25 +240,41 @@ class KVCacheManager:
self.num_layers = config.n_layers self.num_layers = config.n_layers
self.max_len = config.max_len self.max_len = config.max_len
self.num_heads = config.n_kv_heads self.num_heads = config.n_kv_heads
self.head_dim = config.dim //config.n_heads self.head_dim = config.dim // config.n_heads
self._kv_cache: Tuple[Tensor, Tensor] = None self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None self._seq_mask: Tensor = None
self._initialize() self._initialize()
def _initialize(self): def _initialize(self):
k_cache = torch.empty( k_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), (
device=self.device, dtype=self.dtype self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
) )
v_cache = torch.empty( v_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), (
device=self.device, dtype=self.dtype self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
) )
self._kv_cache = (k_cache, v_cache) self._kv_cache = (k_cache, v_cache)
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool) self._seq_mask = torch.ones(
(self.batch_size, self.max_len), device=self.device, dtype=torch.bool
)
def update(self, active_mask: Tensor): def update(self, active_mask: Tensor):
k_cache, v_cache = self._kv_cache k_cache, v_cache = self._kv_cache
self._kv_cache = (k_cache[active_mask], v_cache[active_mask]) self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
self._seq_mask = self._seq_mask[active_mask] self._seq_mask = self._seq_mask[active_mask]
@ -250,14 +285,14 @@ class KVCacheManager:
self._seq_mask = None self._seq_mask = None
else: else:
self._initialize() self._initialize()
def set_seq_mask(self, input_ids: Tensor, pad_id: int): def set_seq_mask(self, input_ids: Tensor, pad_id: int):
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape
bool_mask = (input_ids != pad_id) bool_mask = input_ids != pad_id
self._seq_mask[: batch_size, : seq_len] = bool_mask self._seq_mask[:batch_size, :seq_len] = bool_mask
def get_kvcache(self) -> Tuple[Tensor, Tensor]: def get_kvcache(self) -> Tuple[Tensor, Tensor]:
return self._kv_cache return self._kv_cache
def get_seq_mask(self) -> Tensor: def get_seq_mask(self) -> Tensor:
return self._seq_mask return self._seq_mask

View File

@ -1,6 +1,6 @@
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
from torch import Tensor from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator from typing import List, Tuple, Union, Optional, Generator
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from khaosz.config.param_config import ModelParameter from khaosz.config.param_config import ModelParameter
@ -8,10 +8,11 @@ from khaosz.config.param_config import ModelParameter
HistoryType = List[Tuple[str, str]] HistoryType = List[Tuple[str, str]]
def build_prompt( def build_prompt(
query: str, query: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None history: Optional[HistoryType] = None,
) -> str: ) -> str:
""" """
Build prompt in ChatML format for query and history. Build prompt in ChatML format for query and history.
@ -42,17 +43,17 @@ def build_prompt(
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]: def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
""" """
Pad a list of sequences to a fixed length. Pad a list of sequences to a fixed length.
Args: Args:
ids_list (List[List[int]]): A list of sequences. ids_list (List[List[int]]): A list of sequences.
max_ids_len (int): The maximum length of sequences. max_ids_len (int): The maximum length of sequences.
pad_id (int): The id to pad sequences. pad_id (int): The id to pad sequences.
Returns: Returns:
List[List[int]]: A list of padded sequences. List[List[int]]: A list of padded sequences.
""" """
max_ids_len = max(len(ids) for ids in ids_list) max_ids_len = max(len(ids) for ids in ids_list)
new_ids_list = [] new_ids_list = []
@ -60,7 +61,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
pad_len = max_ids_len - len(ids) pad_len = max_ids_len - len(ids)
padded_seq = [pad_id] * pad_len + ids padded_seq = [pad_id] * pad_len + ids
new_ids_list.append(padded_seq) new_ids_list.append(padded_seq)
return new_ids_list, max_ids_len return new_ids_list, max_ids_len
@ -68,7 +69,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
class GenerationRequest: class GenerationRequest:
""" """
Request parameters for text generation. Request parameters for text generation.
Attributes: Attributes:
top_k: Top-k sampling parameter. top_k: Top-k sampling parameter.
top_p: Top-p (nucleus) sampling parameter. top_p: Top-p (nucleus) sampling parameter.
@ -79,6 +80,7 @@ class GenerationRequest:
system_prompt: System prompt for the conversation. system_prompt: System prompt for the conversation.
stream: Whether to use streaming generation. stream: Whether to use streaming generation.
""" """
top_k: int top_k: int
top_p: float top_p: float
temperature: float temperature: float
@ -101,63 +103,66 @@ class GenerationRequest:
class LoopGenerator(GeneratorCore): class LoopGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter): def __init__(self, parameter: ModelParameter):
super().__init__(parameter) super().__init__(parameter)
def generate(self, request: GenerationRequest) -> str: def generate(self, request: GenerationRequest) -> str:
device = next(self.model.parameters()).device device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device) cache_manager = KVCacheManager(self.config, 1, device=device)
prompt = build_prompt(request.query, request.history) prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt) ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long) input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids) start_cache_pos = len(ids)
self.model.eval() self.model.eval()
kv_caches = cache_manager.get_kvcache() kv_caches = cache_manager.get_kvcache()
ids = self.generate_loop( ids = self.generate_loop(
input_ids, input_ids,
ids, ids,
request.temperature, request.temperature,
request.top_k, request.top_k,
request.top_p, request.top_p,
kv_caches=kv_caches, kv_caches=kv_caches,
) )
response = self.tokenizer.decode(ids[start_cache_pos:]) response = self.tokenizer.decode(ids[start_cache_pos:])
return response return response
class StreamGenerator(GeneratorCore): class StreamGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter): def __init__(self, parameter: ModelParameter):
super().__init__(parameter) super().__init__(parameter)
def generate(self, request: GenerationRequest) -> Generator[str, None, None]: def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
device = next(self.model.parameters()).device device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device) cache_manager = KVCacheManager(self.config, 1, device=device)
prompt = build_prompt(request.query, request.history) prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt) ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long) input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids) start_cache_pos = len(ids)
cur_cache_pos = 0 cur_cache_pos = 0
self.model.eval() self.model.eval()
kv_caches = cache_manager.get_kvcache() kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.max_len): for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_ids, request.temperature, request.top_k, request.top_p, input_ids,
kv_caches=kv_caches, request.temperature,
start_pos=cur_cache_pos request.top_k,
request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
) )
input_ids = next_token_id input_ids = next_token_id
ids.append(next_token_id.item()) ids.append(next_token_id.item())
cur_cache_pos += cache_increase cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:]) response = self.tokenizer.decode(ids[start_cache_pos:])
yield response yield response
if next_token_id.item() in self.tokenizer.stop_ids: if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n" yield response + "\n"
break break
@ -166,131 +171,140 @@ class StreamGenerator(GeneratorCore):
class BatchGenerator(GeneratorCore): class BatchGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter): def __init__(self, parameter: ModelParameter):
super().__init__(parameter) super().__init__(parameter)
def generate(self, request: GenerationRequest) -> List[str]: def generate(self, request: GenerationRequest) -> List[str]:
batch_size = len(request.query) batch_size = len(request.query)
if request.history is None: if request.history is None:
request.history = [[] for _ in range(batch_size)] request.history = [[] for _ in range(batch_size)]
prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)] prompts = [
build_prompt(query, history)
for query, history in zip(request.query, request.history)
]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts] ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id) ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
device = next(self.model.parameters()).device device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, batch_size, device=device) cache_manager = KVCacheManager(self.config, batch_size, device=device)
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long) input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id) cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
activate_task_mask = [True] * batch_size activate_task_mask = [True] * batch_size
start_cache_pos = max_ids_len start_cache_pos = max_ids_len
cur_cache_pos = 0 cur_cache_pos = 0
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0: while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache() kv_caches = cache_manager.get_kvcache()
attn_mask =cache_manager.get_seq_mask() attn_mask = cache_manager.get_seq_mask()
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_tensor, request.temperature, request.top_k, request.top_p, input_tensor,
attn_mask=attn_mask, request.temperature,
kv_caches=kv_caches, request.top_k,
start_pos=cur_cache_pos request.top_p,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
) )
cur_cache_pos += cache_increase cur_cache_pos += cache_increase
active_mask = [] active_mask = []
c_ids = 0 c_ids = 0
for i in range(batch_size): for i in range(batch_size):
if activate_task_mask[i]: if activate_task_mask[i]:
token = next_token_id[c_ids, :].item() token = next_token_id[c_ids, :].item()
ids_list[i].append(token) ids_list[i].append(token)
c_ids += 1 c_ids += 1
is_active = not token in self.tokenizer.stop_ids is_active = not token in self.tokenizer.stop_ids
activate_task_mask[i] = is_active activate_task_mask[i] = is_active
active_mask.append(is_active) active_mask.append(is_active)
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool) active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
cache_manager.update(active_mask) cache_manager.update(active_mask)
input_tensor = next_token_id[active_mask, :] input_tensor = next_token_id[active_mask, :]
max_ids_len += 1 max_ids_len += 1
responses = [str()] * batch_size responses = [str()] * batch_size
for i in range(batch_size): for i in range(batch_size):
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:]) responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
request.history[i].append((request.query[i], responses[i])) request.history[i].append((request.query[i], responses[i]))
return responses return responses
class EmbeddingEncoder(EmbeddingEncoderCore): class EmbeddingEncoder(EmbeddingEncoderCore):
def __init__(self, parameter: ModelParameter): def __init__(self, parameter: ModelParameter):
super().__init__(parameter) super().__init__(parameter)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
return super().encode(sentence) return super().encode(sentence)
class GeneratorFactory: class GeneratorFactory:
"""Factory class for creating generator instances. """Factory class for creating generator instances.
Provides smart generator selection based on request characteristics: Provides smart generator selection based on request characteristics:
- Streaming: Use StreamGenerator for streaming output - Streaming: Use StreamGenerator for streaming output
- Batch: Use BatchGenerator when query is a list - Batch: Use BatchGenerator when query is a list
- Single: Use LoopGenerator for single query non-streaming - Single: Use LoopGenerator for single query non-streaming
Example usage: Example usage:
generator = GeneratorFactory.create_generator(parameter, request) generator = GeneratorFactory.create_generator(parameter, request)
result = generator.generate(request) result = generator.generate(request)
""" """
@staticmethod @staticmethod
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: def create_generator(
parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Create a generator based on request characteristics. """Create a generator based on request characteristics.
Args: Args:
parameter: Model parameters containing model, tokenizer, config parameter: Model parameters containing model, tokenizer, config
request: Generation request with query, options, etc. request: Generation request with query, options, etc.
Returns: Returns:
Appropriate GeneratorCore subclass instance Appropriate GeneratorCore subclass instance
""" """
# Streaming generation: check stream field first # Streaming generation: check stream field first
if request.stream: if request.stream:
return StreamGenerator(parameter) return StreamGenerator(parameter)
# Batch generation: query is a list of strings # Batch generation: query is a list of strings
if isinstance(request.query, list): if isinstance(request.query, list):
return BatchGenerator(parameter) return BatchGenerator(parameter)
# Default: single query non-streaming # Default: single query non-streaming
return LoopGenerator(parameter) return LoopGenerator(parameter)
@staticmethod @staticmethod
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore: def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
"""Create an embedding encoder instance. """Create an embedding encoder instance.
Args: Args:
parameter: Model parameters parameter: Model parameters
Returns: Returns:
EmbeddingEncoderCore instance EmbeddingEncoderCore instance
""" """
return EmbeddingEncoder(parameter) return EmbeddingEncoder(parameter)
@classmethod @classmethod
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: def create(
cls, parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Convenience method that delegates to create_generator. """Convenience method that delegates to create_generator.
Args: Args:
parameter: Model parameters parameter: Model parameters
request: Generation request request: Generation request
Returns: Returns:
Generator instance Generator instance
""" """
return cls.create_generator(parameter, request) return cls.create_generator(parameter, request)

View File

@ -1,17 +1,10 @@
from khaosz.model.module import ( from khaosz.model.module import (
Linear, Linear,
RMSNorm, RMSNorm,
MLP, MLP,
GQA, GQA,
DecoderBlock, DecoderBlock,
) )
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
__all__ = [ __all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
"Transformer"
]

View File

@ -7,7 +7,7 @@ from typing import Optional, Tuple
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
""" """
Repeat k times along the dimension for attention heads. Repeat k times along the dimension for attention heads.
Args: Args:
x (Tensor): The input tensor. x (Tensor): The input tensor.
@ -15,7 +15,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
Returns: Returns:
Tensor: The repeated tensor. Tensor: The repeated tensor.
""" """
bs, slen, n_heads, head_dim = x.shape bs, slen, n_heads, head_dim = x.shape
if n_rep == 1: if n_rep == 1:
return x return x
@ -25,12 +25,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
.reshape(bs, slen, n_heads * n_rep, head_dim) .reshape(bs, slen, n_heads * n_rep, head_dim)
) )
def get_rotary_emb( def get_rotary_emb(
dim: int, dim: int,
max_len: int, max_len: int,
base: float = 10000, base: float = 10000,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
""" """
Get the rotary embedding for the given dimension and maximum length. Get the rotary embedding for the given dimension and maximum length.
Args: Args:
dim (int): The dimension of the input. dim (int): The dimension of the input.
@ -46,6 +47,7 @@ def get_rotary_emb(
return torch.cos(freqs).float(), torch.sin(freqs).float() return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
""" """
Apply rotary embedding to the input tensor using cos/sin form. Apply rotary embedding to the input tensor using cos/sin form.
@ -55,49 +57,49 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
Returns: Returns:
Tensor: The output tensor (rotated, same shape as input). Tensor: The output tensor (rotated, same shape as input).
""" """
dtype = x.dtype dtype = x.dtype
cos, sin = rotary_emb cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2] x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2] x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
x_real_rot = x_real * cos - x_imag * sin x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2] x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim] x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype) return x_out.to(dtype)
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int=10000): def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.max_len = max_len self.max_len = max_len
self.base = base self.base = base
self.max_len_cached = None self.max_len_cached = None
self._set_rotary_buffer(self.max_len) self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int): def _set_rotary_buffer(self, max_len: int):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base) cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len self.max_len_cached = max_len
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]: def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1) seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos: if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(seq_len + start_pos) self._set_rotary_buffer(seq_len + start_pos)
cos = self.cos_cached[start_pos : start_pos + seq_len] cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin) return (cos, sin)
@ -115,43 +117,42 @@ class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps): def __init__(self, dim, norm_eps):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(dim)) self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim, ) self.normalized_shape = (dim,)
self.norm_eps = norm_eps self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
return rms.to(x.dtype) return rms.to(x.dtype)
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int): def __init__(self, dim: int, dim_feed_forward: int):
super().__init__() super().__init__()
self.up = Linear(dim, dim_feed_forward) self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward) self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim) self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x)) gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated) out = self.down(gated)
return out return out
class GQA(nn.Module): class GQA(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
n_heads: int, n_heads: int,
n_kv_heads: int, n_kv_heads: int,
use_qk_norm: bool, use_qk_norm: bool,
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int layer_id: int,
): ):
super().__init__() super().__init__()
assert dim % n_heads == 0 assert dim % n_heads == 0
assert n_heads % n_kv_heads == 0 assert n_heads % n_kv_heads == 0
self.head_dim = dim // n_heads self.head_dim = dim // n_heads
self.layer_id = layer_id self.layer_id = layer_id
self.dim = dim self.dim = dim
@ -165,11 +166,11 @@ class GQA(nn.Module):
self.k_proj = Linear(dim, n_kv_heads * self.head_dim) self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim) self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim) self.o_proj = Linear(dim, dim)
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps) self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps) self.k_norm = RMSNorm(self.head_dim, norm_eps)
if self.use_gated_attention: if self.use_gated_attention:
self.gate = Linear(dim, dim) self.gate = Linear(dim, dim)
@ -177,14 +178,14 @@ class GQA(nn.Module):
batch_size, seq_len, _ = x.shape batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim) x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x return x
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
@ -194,31 +195,36 @@ class GQA(nn.Module):
k = self._split_heads(self.k_proj(x), self.n_kv_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb) q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if self.use_qk_norm: if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k) q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None: if kv_cache is not None:
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
# copy to cache # copy to cache
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache # get cache
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim) # (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2) sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
)
if self.use_gated_attention: if self.use_gated_attention:
sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out) out = self.o_proj(sdqa_out)
return out return out
@ -227,15 +233,15 @@ class GQA(nn.Module):
class MLA(nn.Module): class MLA(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
n_heads: int, n_heads: int,
n_kv_heads: int, n_kv_heads: int,
kv_lora_rank: int, kv_lora_rank: int,
qk_nope_head_dim: int, qk_nope_head_dim: int,
qk_rope_head_dim: int, qk_rope_head_dim: int,
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int layer_id: int,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -252,45 +258,46 @@ class MLA(nn.Module):
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False) self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps) self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
# KV (k_nope, k_rope, v) # KV (k_nope, k_rope, v)
self.kv_b_proj = Linear( self.kv_b_proj = Linear(
kv_lora_rank, kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
) )
self.o_proj = Linear(dim, dim, bias=False) self.o_proj = Linear(dim, dim, bias=False)
if use_gated_attention: if use_gated_attention:
self.gate = Linear(dim, dim, bias=False) self.gate = Linear(dim, dim, bias=False)
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
q = self.q_proj(x) q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim) q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
kv_compressed = self.kv_a_proj(x) kv_compressed = self.kv_a_proj(x)
kv_compressed = self.kv_norm(kv_compressed) kv_compressed = self.kv_norm(kv_compressed)
kv = self.kv_b_proj(kv_compressed) kv = self.kv_b_proj(kv_compressed)
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1) kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
k_nope, k_rope, v = torch.split( k_nope, k_rope, v = torch.split(
kv, kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
[self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], )
dim=-1
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_rope_head_dim :],
) )
q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_rope_head_dim:]
q_rope = apply_rotary_emb(q_rope, rotary_emb) q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb) k_rope = apply_rotary_emb(k_rope, rotary_emb)
@ -299,41 +306,48 @@ class MLA(nn.Module):
if kv_cache is not None: if kv_cache is not None:
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2) attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention: if self.use_gated_attention:
attn_out = attn_out * F.sigmoid(self.gate(x)) attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out) out = self.o_proj(attn_out)
return out return out
class DecoderBlock(nn.Module): class DecoderBlock(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
n_heads: int, n_heads: int,
dim_ffn: int, dim_ffn: int,
n_kv_heads: int, n_kv_heads: int,
norm_eps: int, norm_eps: int,
use_qk_norm: bool, use_qk_norm: bool,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int layer_id: int,
): ):
super().__init__() super().__init__()
self.attention = GQA(dim, n_heads, n_kv_heads, self.attention = GQA(
use_qk_norm, norm_eps, use_gated_attention, layer_id) dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps) self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn) self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps) self.post_attention_norm = RMSNorm(dim, norm_eps)
@ -341,24 +355,20 @@ class DecoderBlock(nn.Module):
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Tensor:
# attention # attention
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
rotary_emb,
attention_mask,
kv_cache,
start_pos
) )
x = attn_output + x x = attn_output + x
# feed forward # feed forward
x = self.mlp(self.post_attention_norm(x)) + x x = self.mlp(self.post_attention_norm(x)) + x
return x return x
@ -366,6 +376,6 @@ class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int): def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim))) self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight) return F.embedding(x, self.weight)

View File

@ -4,15 +4,21 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from typing import Any, Mapping, Optional, Tuple from typing import Any, Mapping, Optional, Tuple
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding from khaosz.model.module import (
Embedding,
DecoderBlock,
Linear,
RMSNorm,
RotaryEmbedding,
)
def process_attention_mask( def process_attention_mask(
seq_mask: Tensor, seq_mask: Tensor,
input_tensor: Tensor, input_tensor: Tensor,
start_pos: int = 0, start_pos: int = 0,
is_causal: bool = False, is_causal: bool = False,
) -> Tensor: ) -> Tensor:
""" """
Create attention mask for GQA Create attention mask for GQA
Args: Args:
@ -26,32 +32,36 @@ def process_attention_mask(
device = input_tensor.device device = input_tensor.device
dtype = input_tensor.dtype dtype = input_tensor.dtype
seq_len = input_tensor.size(1) seq_len = input_tensor.size(1)
if seq_mask is None: if seq_mask is None:
if start_pos != 0: if start_pos != 0:
# for single prompt chat # for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else: else:
return None return None
if seq_mask.dim() > 2: if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos) # shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor # if ndim > 2, it's 4D tensor
return seq_mask return seq_mask
batch_size = seq_mask.size(0) batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool) seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len) # (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len) expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
# (bsz, seq_len, start_pos + seq_len) # (bsz, seq_len, start_pos + seq_len)
if is_causal: if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1) attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos) # (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask return attention_mask
@ -59,26 +69,38 @@ class Transformer(nn.Module):
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len) self.rotary_embeding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
)
self.embed_tokens = Embedding(config.vocab_size, config.dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList(
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads, [
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id) DecoderBlock(
for layer_id in range(config.n_layers) config.dim,
]) config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps) self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size) self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight == True: if self.config.tie_weight == True:
self.lm_head.weight = self.embed_tokens.weight self.lm_head.weight = self.embed_tokens.weight
self._init_parameters() self._init_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False): def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = 'lm_head.weight' lm_head_key = "lm_head.weight"
embed_key = 'embed_tokens.weight' embed_key = "embed_tokens.weight"
if self.config.tie_weight == True: if self.config.tie_weight == True:
# same tensor # same tensor
@ -87,48 +109,44 @@ class Transformer(nn.Module):
if lm_head_key not in state_dict and embed_key in state_dict: if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor # use clone to avoid sharing the same tensor
state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign) return super().load_state_dict(state_dict, strict, assign)
def state_dict(self, destination=None, prefix='', keep_vars=False): def state_dict(self, destination=None, prefix="", keep_vars=False):
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) state_dict = super().state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight == True: if self.config.tie_weight == True:
lm_head_key = prefix + 'lm_head.weight' lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict: if lm_head_key in state_dict:
del state_dict[lm_head_key] del state_dict[lm_head_key]
return state_dict return state_dict
def _init_parameters(self): def _init_parameters(self):
for param in self.parameters(): for param in self.parameters():
if param.dim() > 1: if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006) nn.init.normal_(param, mean=0.0, std=0.006)
def forward( def forward(
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor]=None, input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None, persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embeding(x, start_pos) rotary_emb = self.rotary_embeding(x, start_pos)
attn_mask = process_attention_mask( attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
input_mask, x, start_pos, is_causal=True
)
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos) x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return { return {"logits": logits, "hidden_states": hidden_states}
"logits": logits,
"hidden_states": hidden_states
}

View File

@ -1,27 +1,21 @@
from khaosz.parallel.setup import ( from khaosz.parallel.setup import (
get_world_size, get_world_size,
get_rank, get_rank,
get_current_device, get_current_device,
only_on_rank, only_on_rank,
setup_parallel, setup_parallel,
spawn_parallel_fn spawn_parallel_fn,
) )
from khaosz.parallel.module import ( from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
RowParallelLinear,
ColumnParallelLinear
)
__all__ = [ __all__ = [
"get_world_size", "get_world_size",
"get_rank", "get_rank",
"get_current_device", "get_current_device",
"only_on_rank", "only_on_rank",
"setup_parallel", "setup_parallel",
"spawn_parallel_fn", "spawn_parallel_fn",
"RowParallelLinear", "RowParallelLinear",
"ColumnParallelLinear" "ColumnParallelLinear",
] ]

View File

@ -17,91 +17,99 @@ class ParallelModel(nn.Module):
class RowParallelLinear(ParallelModel): class RowParallelLinear(ParallelModel):
def __init__( def __init__(
self, self,
process_group: dist.ProcessGroup, process_group: dist.ProcessGroup,
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
reduce_results: bool = True reduce_results: bool = True,
): ):
super().__init__(process_group) super().__init__(process_group)
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.in_features_per_rank = in_features // self.world_size self.in_features_per_rank = in_features // self.world_size
self.reduce_results = reduce_results self.reduce_results = reduce_results
if in_features % self.world_size != 0: if in_features % self.world_size != 0:
raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}") raise ValueError(
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
)
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank)) self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight) output = F.linear(input, self.weight)
if self.reduce_results: if self.reduce_results:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
if self.bias is not None: if self.bias is not None:
output += self.bias output += self.bias
return output return output
def load_state_dict(self, state_dict: Dict[str, Tensor]): def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get('weight') full_weight = state_dict.get("weight")
full_bias = state_dict.get('bias') full_bias = state_dict.get("bias")
start_idx = self.rank * self.in_features_per_rank start_idx = self.rank * self.in_features_per_rank
end_idx = start_idx + self.in_features_per_rank end_idx = start_idx + self.in_features_per_rank
weight_slice = full_weight[:, start_idx:end_idx] weight_slice = full_weight[:, start_idx:end_idx]
self.weight.data.copy_(weight_slice) self.weight.data.copy_(weight_slice)
if self.bias is not None: if self.bias is not None:
self.bias.data.copy_(full_bias) self.bias.data.copy_(full_bias)
class ColumnParallelLinear(ParallelModel): class ColumnParallelLinear(ParallelModel):
def __init__( def __init__(
self, self,
process_group: dist.ProcessGroup, process_group: dist.ProcessGroup,
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
gather_results: bool = True gather_results: bool = True,
): ):
super().__init__(process_group) super().__init__(process_group)
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.out_features_per_rank = out_features // self.world_size self.out_features_per_rank = out_features // self.world_size
self.gather_results = gather_results self.gather_results = gather_results
if out_features % self.world_size != 0: if out_features % self.world_size != 0:
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}") raise ValueError(
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
)
self.weight = nn.Parameter(
torch.empty(self.out_features_per_rank, self.in_features)
)
self.bias = (
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
)
self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight, self.bias) output = F.linear(input, self.weight, self.bias)
if self.gather_results: if self.gather_results:
output_list = [torch.empty_like(output) for _ in range(self.world_size)] output_list = [torch.empty_like(output) for _ in range(self.world_size)]
dist.all_gather(output_list, output, group=self.process_group) dist.all_gather(output_list, output, group=self.process_group)
output = torch.cat(output_list, dim=-1) output = torch.cat(output_list, dim=-1)
return output return output
def load_state_dict(self, state_dict: Dict[str, Tensor]): def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get('weight') full_weight = state_dict.get("weight")
full_bias = state_dict.get('bias') full_bias = state_dict.get("bias")
start_idx = self.rank * self.out_features_per_rank start_idx = self.rank * self.out_features_per_rank
end_idx = start_idx + self.out_features_per_rank end_idx = start_idx + self.out_features_per_rank
weight_slice = full_weight[start_idx:end_idx, :] weight_slice = full_weight[start_idx:end_idx, :]
self.weight.data.copy_(weight_slice) self.weight.data.copy_(weight_slice)
if self.bias is not None: if self.bias is not None:
bias_slice = full_bias[start_idx:end_idx] bias_slice = full_bias[start_idx:end_idx]
self.bias.data.copy_(bias_slice) self.bias.data.copy_(bias_slice)

View File

@ -11,73 +11,74 @@ from typing import Callable, List, Optional
def get_current_device(): def get_current_device():
return os.environ["LOCAL_DEVICE"] return os.environ["LOCAL_DEVICE"]
def get_world_size() -> int: def get_world_size() -> int:
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
return dist.get_world_size() return dist.get_world_size()
else: else:
return 1 return 1
def get_rank() -> int: def get_rank() -> int:
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
return dist.get_rank() return dist.get_rank()
else: else:
return 0 return 0
@contextmanager @contextmanager
def setup_parallel( def setup_parallel(
rank: int, rank: int,
world_size: int, world_size: int,
backend: str = "nccl", backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
device_ids: Optional[List[int]] = None device_ids: Optional[List[int]] = None,
): ):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
yield dist.group.WORLD yield dist.group.WORLD
return return
if world_size <= 1: if world_size <= 1:
yield None yield None
return return
if device_ids is None: if device_ids is None:
device_ids = [i for i in range(world_size)] device_ids = [i for i in range(world_size)]
rank = device_ids[rank % len(device_ids)] rank = device_ids[rank % len(device_ids)]
device_id = torch.device(device_type, device_ids[rank]) device_id = torch.device(device_type, device_ids[rank])
os.environ['MASTER_ADDR'] = master_addr os.environ["MASTER_ADDR"] = master_addr
os.environ['MASTER_PORT'] = master_port os.environ["MASTER_PORT"] = master_port
os.environ['LOCAL_RANK'] = str(rank) os.environ["LOCAL_RANK"] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
dist.init_process_group( dist.init_process_group(
rank=rank, rank=rank, world_size=world_size, backend=backend, device_id=device_id
world_size=world_size,
backend=backend,
device_id=device_id
) )
try: try:
if backend == "nccl" and torch.cuda.is_available(): if backend == "nccl" and torch.cuda.is_available():
torch.cuda.set_device(device_id) torch.cuda.set_device(device_id)
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available(): elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.set_device(device_id) torch.xpu.set_device(device_id)
yield dist.group.WORLD yield dist.group.WORLD
finally: finally:
if dist.is_initialized(): if dist.is_initialized():
dist.destroy_process_group() dist.destroy_process_group()
def only_on_rank(rank, sync=False): def only_on_rank(rank, sync=False):
""" """
decorator to run a function only on a specific rank. decorator to run a function only on a specific rank.
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -89,67 +90,81 @@ def only_on_rank(rank, sync=False):
dist.barrier() dist.barrier()
return ret_args return ret_args
return wrapper return wrapper
return decorator return decorator
def wrapper_spawn_func( def wrapper_spawn_func(
rank: int, rank: int,
world_size: int, world_size: int,
backend: str, backend: str,
master_addr: str, master_addr: str,
master_port: str, master_port: str,
device_type: str, device_type: str,
device_ids: List[int], device_ids: List[int],
func: Callable, func: Callable,
kwargs: dict kwargs: dict,
): ):
try: try:
with setup_parallel( with setup_parallel(
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
backend=backend, backend=backend,
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
device_type=device_type, device_type=device_type,
device_ids=device_ids device_ids=device_ids,
): ):
func(**kwargs) func(**kwargs)
except Exception as e: except Exception as e:
print(f"Error in rank {rank}: {e}") print(f"Error in rank {rank}: {e}")
raise raise
def spawn_parallel_fn( def spawn_parallel_fn(
func: Callable, func: Callable,
world_size: int, world_size: int,
backend: str = "nccl", backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
device_ids: Optional[List[int]] = None, device_ids: Optional[List[int]] = None,
**kwargs **kwargs,
): ):
# clear environment variables # clear environment variables
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']: for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ: if key in os.environ:
del os.environ[key] del os.environ[key]
if world_size == 1: if world_size == 1:
device_ids = device_ids or [0] device_ids = device_ids or [0]
device_id = torch.device(device_type, device_ids[0]) device_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs) func(**kwargs)
return return
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port, wrapper_spawn_func_args = (
device_type, device_ids, func, kwargs) world_size,
backend,
master_addr,
master_port,
device_type,
device_ids,
func,
kwargs,
)
mp.spawn( mp.spawn(
wrapper_spawn_func, wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
nprocs=world_size, )
args=wrapper_spawn_func_args,
join=True
)

View File

@ -14,15 +14,12 @@ from khaosz.trainer.train_callback import (
__all__ = [ __all__ = [
# Main trainer # Main trainer
"Trainer", "Trainer",
# Strategy factory # Strategy factory
"StrategyFactory", "StrategyFactory",
"BaseStrategy", "BaseStrategy",
# Scheduler factory # Scheduler factory
"SchedulerFactory", "SchedulerFactory",
"BaseScheduler", "BaseScheduler",
# Callbacks # Callbacks
"TrainCallback", "TrainCallback",
"GradientClippingCallback", "GradientClippingCallback",
@ -30,4 +27,4 @@ __all__ = [
"CheckpointCallback", "CheckpointCallback",
"ProgressBarCallback", "ProgressBarCallback",
"MetricLoggerCallback", "MetricLoggerCallback",
] ]

View File

@ -1,8 +1,9 @@
import torch.nn as nn import torch.nn as nn
from typing import Dict from typing import Dict
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
""" Compute gradient norm for each parameter in the model. """ """Compute gradient norm for each parameter in the model."""
norms = {} norms = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
norms[name] = 0.0 norms[name] = 0.0
@ -11,8 +12,9 @@ def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
norms[name] = norm norms[name] = norm
return norms return norms
def grad_std(model: nn.Module) -> Dict[str, float]: def grad_std(model: nn.Module) -> Dict[str, float]:
""" Compute standard deviation of gradients for each parameter. """ """Compute standard deviation of gradients for each parameter."""
stds = {} stds = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
stds[name] = 0.0 stds[name] = 0.0
@ -21,41 +23,45 @@ def grad_std(model: nn.Module) -> Dict[str, float]:
stds[name] = std stds[name] = std
return stds return stds
def grad_max(model: nn.Module) -> Dict[str, float]: def grad_max(model: nn.Module) -> Dict[str, float]:
""" Find the maximum absolute gradient value for each parameter. """ """Find the maximum absolute gradient value for each parameter."""
max_vals = {} max_vals = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
max_vals[name] = -float('inf') max_vals[name] = -float("inf")
if param.grad: if param.grad:
max_val = param.grad.data.max().item() max_val = param.grad.data.max().item()
max_vals[name] = max_val max_vals[name] = max_val
return max_vals return max_vals
def grad_min(model: nn.Module) -> Dict[str, float]: def grad_min(model: nn.Module) -> Dict[str, float]:
""" Find the minimum absolute gradient value for each parameter. """ """Find the minimum absolute gradient value for each parameter."""
min_vals = {} min_vals = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
min_vals[name] = float('inf') min_vals[name] = float("inf")
if param.grad: if param.grad:
min_val = param.grad.data.min().item() min_val = param.grad.data.min().item()
min_vals[name] = min_val min_vals[name] = min_val
return min_vals return min_vals
def grad_mean(model: nn.Module) -> Dict[str, float]: def grad_mean(model: nn.Module) -> Dict[str, float]:
""" Compute mean of gradients for each parameter. """ """Compute mean of gradients for each parameter."""
means = {} means = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
means[name] = 0.0 means[name] = 0.0
if param.grad: if param.grad:
mean = param.grad.data.mean().item() mean = param.grad.data.mean().item()
means[name] = mean means[name] = mean
return means return means
def grad_nan_num(model: nn.Module) -> Dict[str, int]: def grad_nan_num(model: nn.Module) -> Dict[str, int]:
""" Count the number of NaNs in gradients for each parameter. """ """Count the number of NaNs in gradients for each parameter."""
nan_nums = {} nan_nums = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
nan_nums[name] = 0 nan_nums[name] = 0
@ -64,26 +70,34 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]:
nan_nums[name] = nan_num nan_nums[name] = nan_num
return nan_nums return nan_nums
def ctx_get_loss(ctx): def ctx_get_loss(ctx):
return ctx.loss return ctx.loss
def ctx_get_lr(ctx): def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]['lr'] return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_grad_norm(ctx): def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model) return grad_norm(ctx.model)
def ctx_get_grad_std(ctx): def ctx_get_grad_std(ctx):
return grad_std(ctx.model) return grad_std(ctx.model)
def ctx_get_grad_max(ctx): def ctx_get_grad_max(ctx):
return grad_max(ctx.model) return grad_max(ctx.model)
def ctx_get_grad_min(ctx): def ctx_get_grad_min(ctx):
return grad_min(ctx.model) return grad_min(ctx.model)
def ctx_get_grad_mean(ctx): def ctx_get_grad_mean(ctx):
return grad_mean(ctx.model) return grad_mean(ctx.model)
def ctx_get_grad_nan_num(ctx): def ctx_get_grad_nan_num(ctx):
return grad_nan_num(ctx.model) return grad_nan_num(ctx.model)

View File

@ -9,71 +9,75 @@ from khaosz.config.schedule_config import ScheduleConfig
class BaseScheduler(LRScheduler, ABC): class BaseScheduler(LRScheduler, ABC):
"""Base scheduler class for all other schedulers.""" """Base scheduler class for all other schedulers."""
def __init__(self, optimizer, last_epoch: int = -1): def __init__(self, optimizer, last_epoch: int = -1):
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
@abstractmethod @abstractmethod
def get_lr(self) -> List[float]: def get_lr(self) -> List[float]:
"""Calculate the current learning rate.""" """Calculate the current learning rate."""
raise NotImplementedError raise NotImplementedError
def state_dict(self) -> Dict[str, Any]: def state_dict(self) -> Dict[str, Any]:
return super().state_dict() return super().state_dict()
def load_state_dict(self, state_dict: Dict[str, Any]): def load_state_dict(self, state_dict: Dict[str, Any]):
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
class SchedulerFactory: class SchedulerFactory:
"""Factory class for creating learning rate schedulers. """Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types. Supports decorator-based registration for extensible scheduler types.
Also supports creation from ScheduleConfig objects. Also supports creation from ScheduleConfig objects.
Example usage: Example usage:
@SchedulerFactory.register("custom") @SchedulerFactory.register("custom")
class CustomScheduler(BaseScheduler): class CustomScheduler(BaseScheduler):
... ...
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs) scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
# Or from config # Or from config
config = CosineScheduleConfig(total_steps=10000) config = CosineScheduleConfig(total_steps=10000)
scheduler = SchedulerFactory.load(optimizer, config) scheduler = SchedulerFactory.load(optimizer, config)
""" """
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {} SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
@classmethod @classmethod
def register(cls, name: str): def register(cls, name: str):
"""Decorator to register a new scheduler class. """Decorator to register a new scheduler class.
Args: Args:
name: Registration name for the scheduler name: Registration name for the scheduler
Returns: Returns:
Decorator function that registers the scheduler class Decorator function that registers the scheduler class
""" """
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]: def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
if not issubclass(scheduler_cls, BaseScheduler): if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") raise TypeError(
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
)
cls.SCHEDULER_MAP[name] = scheduler_cls cls.SCHEDULER_MAP[name] = scheduler_cls
return scheduler_cls return scheduler_cls
return decorator return decorator
@classmethod @classmethod
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler: def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
"""Create a scheduler instance by type name. """Create a scheduler instance by type name.
Args: Args:
optimizer: PyTorch optimizer optimizer: PyTorch optimizer
schedule_type: Type of scheduler ("cosine", "sgdr") schedule_type: Type of scheduler ("cosine", "sgdr")
**kwargs: Arguments passed to the scheduler constructor **kwargs: Arguments passed to the scheduler constructor
Returns: Returns:
Scheduler instance Scheduler instance
Raises: Raises:
ValueError: If schedule_type is not supported ValueError: If schedule_type is not supported
""" """
@ -82,25 +86,25 @@ class SchedulerFactory:
f"Unknown schedule type: '{schedule_type}'. " f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}" f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
) )
scheduler_cls = cls.SCHEDULER_MAP[schedule_type] scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
return scheduler_cls(optimizer, **kwargs) return scheduler_cls(optimizer, **kwargs)
@staticmethod @staticmethod
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler: def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
"""Create a scheduler from a ScheduleConfig object. """Create a scheduler from a ScheduleConfig object.
Args: Args:
optimizer: PyTorch optimizer optimizer: PyTorch optimizer
schedule_config: ScheduleConfig instance schedule_config: ScheduleConfig instance
Returns: Returns:
Scheduler instance Scheduler instance
""" """
kwargs = schedule_config.get_kwargs() kwargs = schedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type") schedule_type = kwargs.pop("schedule_type")
return SchedulerFactory.create(optimizer, schedule_type, **kwargs) return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
@classmethod @classmethod
def available_types(cls) -> list: def available_types(cls) -> list:
"""Return list of registered scheduler type names.""" """Return list of registered scheduler type names."""
@ -114,22 +118,21 @@ class SchedulerFactory:
@SchedulerFactory.register("cosine") @SchedulerFactory.register("cosine")
class CosineScheduler(BaseScheduler): class CosineScheduler(BaseScheduler):
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.""" """Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
def __init__( def __init__(
self, self,
optimizer, optimizer,
warmup_steps: int, warmup_steps: int,
lr_decay_steps: int, lr_decay_steps: int,
min_rate: float = 0.05, min_rate: float = 0.05,
last_epoch: int = -1 last_epoch: int = -1,
): ):
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.lr_decay_steps = lr_decay_steps self.lr_decay_steps = lr_decay_steps
self.min_rate = min_rate self.min_rate = min_rate
self.total_steps = warmup_steps + lr_decay_steps self.total_steps = warmup_steps + lr_decay_steps
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]: def get_lr(self) -> List[float]:
# warmup # warmup
if self.last_epoch < self.warmup_steps: if self.last_epoch < self.warmup_steps:
@ -142,46 +145,47 @@ class CosineScheduler(BaseScheduler):
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress)) cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
decay_factor = max(self.min_rate, cosine_decay) decay_factor = max(self.min_rate, cosine_decay)
return [base_lr * decay_factor for base_lr in self.base_lrs] return [base_lr * decay_factor for base_lr in self.base_lrs]
def state_dict(self): def state_dict(self):
state = super().state_dict() state = super().state_dict()
state.update({ state.update(
'warmup_steps': self.warmup_steps, {
'lr_decay_steps': self.lr_decay_steps, "warmup_steps": self.warmup_steps,
'min_rate': self.min_rate, "lr_decay_steps": self.lr_decay_steps,
'total_steps': self.total_steps, "min_rate": self.min_rate,
}) "total_steps": self.total_steps,
}
)
return state return state
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self.warmup_steps = state_dict.pop('warmup_steps') self.warmup_steps = state_dict.pop("warmup_steps")
self.lr_decay_steps = state_dict.pop('lr_decay_steps') self.lr_decay_steps = state_dict.pop("lr_decay_steps")
self.min_rate = state_dict.pop('min_rate') self.min_rate = state_dict.pop("min_rate")
self.total_steps = state_dict.pop('total_steps') self.total_steps = state_dict.pop("total_steps")
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
@SchedulerFactory.register("sgdr") @SchedulerFactory.register("sgdr")
class SGDRScheduler(BaseScheduler): class SGDRScheduler(BaseScheduler):
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler.""" """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
def __init__( def __init__(
self, self,
optimizer, optimizer,
warmup_steps: int, warmup_steps: int,
cycle_length: int, cycle_length: int,
min_rate: float = 0.05, min_rate: float = 0.05,
t_mult: int = 2, t_mult: int = 2,
last_epoch: int = -1, last_epoch: int = -1,
): ):
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.cycle_length = cycle_length self.cycle_length = cycle_length
self.min_rate = min_rate self.min_rate = min_rate
self.t_mult = t_mult self.t_mult = t_mult
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
def get_lr(self): def get_lr(self):
# warmup # warmup
if self.last_epoch < self.warmup_steps: if self.last_epoch < self.warmup_steps:
@ -190,40 +194,44 @@ class SGDRScheduler(BaseScheduler):
# SGDR # SGDR
steps_since_warmup = self.last_epoch - self.warmup_steps steps_since_warmup = self.last_epoch - self.warmup_steps
# 1. Calculate current cycle and position within cycle # 1. Calculate current cycle and position within cycle
current_cycle_length = self.cycle_length current_cycle_length = self.cycle_length
total_cycles_length = 0 total_cycles_length = 0
cycle_num = 0 cycle_num = 0
while total_cycles_length + current_cycle_length <= steps_since_warmup: while total_cycles_length + current_cycle_length <= steps_since_warmup:
total_cycles_length += current_cycle_length total_cycles_length += current_cycle_length
current_cycle_length *= self.t_mult current_cycle_length *= self.t_mult
cycle_num += 1 cycle_num += 1
steps_in_cycle = steps_since_warmup - total_cycles_length steps_in_cycle = steps_since_warmup - total_cycles_length
# 2. Cosine annealing within the current cycle # 2. Cosine annealing within the current cycle
cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)) cosine_factor = 0.5 * (
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
)
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
return [base_lr * learning_rate_factor for base_lr in self.base_lrs] return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
def state_dict(self): def state_dict(self):
"""Returns the state of the scheduler as a dict.""" """Returns the state of the scheduler as a dict."""
state = super().state_dict() state = super().state_dict()
state.update({ state.update(
'warmup_steps': self.warmup_steps, {
'cycle_length': self.cycle_length, "warmup_steps": self.warmup_steps,
'min_rate': self.min_rate, "cycle_length": self.cycle_length,
't_mult': self.t_mult "min_rate": self.min_rate,
}) "t_mult": self.t_mult,
}
)
return state return state
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
"""Loads the scheduler's state.""" """Loads the scheduler's state."""
self.warmup_steps = state_dict.pop('warmup_steps') self.warmup_steps = state_dict.pop("warmup_steps")
self.cycle_length = state_dict.pop('cycle_length') self.cycle_length = state_dict.pop("cycle_length")
self.min_rate = state_dict.pop('min_rate') self.min_rate = state_dict.pop("min_rate")
self.t_mult = state_dict.pop('t_mult') self.t_mult = state_dict.pop("t_mult")
super().load_state_dict(state_dict) super().load_state_dict(state_dict)

View File

@ -20,7 +20,7 @@ def unwrap_model(model: nn.Module) -> nn.Module:
def create_ref_model(model: nn.Module) -> nn.Module: def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training. """Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first, Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients. then creating a deep copy with frozen gradients.
""" """
@ -37,25 +37,27 @@ def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
def get_logprobs( def get_logprobs(
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
input_ids: Tensor, input_ids: Tensor,
mask: Tensor, mask: Tensor,
reduction: str, reduction: str,
): ):
"""Compute token-wise log probabilities from model outputs. """Compute token-wise log probabilities from model outputs.
Args: Args:
model: The language model model: The language model
input_ids: Input token IDs of shape [batch_size, seq_len] input_ids: Input token IDs of shape [batch_size, seq_len]
mask: Attention mask of shape [batch_size, seq_len] mask: Attention mask of shape [batch_size, seq_len]
reduction: How to reduce over sequence dimension ("mean", "sum", "none") reduction: How to reduce over sequence dimension ("mean", "sum", "none")
Returns: Returns:
Log probabilities with reduction applied over sequence dimension Log probabilities with reduction applied over sequence dimension
""" """
allowed_reductions = ["mean", "sum", "none"] allowed_reductions = ["mean", "sum", "none"]
if reduction not in allowed_reductions: if reduction not in allowed_reductions:
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'") raise ValueError(
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
)
shifted_input_ids = input_ids[:, 1:] shifted_input_ids = input_ids[:, 1:]
shifted_mask = mask[:, 1:] shifted_mask = mask[:, 1:]
@ -64,13 +66,13 @@ def get_logprobs(
log_probs = torch.log_softmax(logits.float(), dim=-1) log_probs = torch.log_softmax(logits.float(), dim=-1)
token_logprobs = torch.gather( token_logprobs = torch.gather(
log_probs, log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
dim=-1,
index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1) ).squeeze(-1)
if reduction == "mean": if reduction == "mean":
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0) return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
dim=-1
).clamp(min=1.0)
elif reduction == "sum": elif reduction == "sum":
return (token_logprobs * shifted_mask).sum(dim=-1) return (token_logprobs * shifted_mask).sum(dim=-1)
else: else:
@ -79,23 +81,25 @@ def get_logprobs(
class BaseStrategy(ABC): class BaseStrategy(ABC):
"""Abstract base class for training strategies.""" """Abstract base class for training strategies."""
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): def __init__(
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
):
self.model = model self.model = model
self.device = device self.device = device
@abstractmethod @abstractmethod
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
"""Compute loss for the given batch. """Compute loss for the given batch.
Args: Args:
batch: Dictionary containing batch tensors batch: Dictionary containing batch tensors
Returns: Returns:
Computed loss tensor Computed loss tensor
""" """
raise NotImplementedError raise NotImplementedError
def __call__(self, batch: Dict[str, Tensor]) -> Tensor: def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
"""Allow calling strategy directly as a callable.""" """Allow calling strategy directly as a callable."""
return self.compute_loss(batch) return self.compute_loss(batch)
@ -103,51 +107,55 @@ class BaseStrategy(ABC):
class StrategyFactory: class StrategyFactory:
"""Factory class for creating training strategy instances. """Factory class for creating training strategy instances.
Supports decorator-based registration for extensible strategy types. Supports decorator-based registration for extensible strategy types.
All default strategies (seq, sft, dpo, grpo) are automatically registered. All default strategies (seq, sft, dpo, grpo) are automatically registered.
Example usage: Example usage:
@StrategyFactory.register("custom") @StrategyFactory.register("custom")
class CustomStrategy(BaseStrategy): class CustomStrategy(BaseStrategy):
... ...
strategy = StrategyFactory.create(model, "custom", device) strategy = StrategyFactory.create(model, "custom", device)
""" """
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"}) SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
STRATEGY_MAP: Dict[str, type] = {} STRATEGY_MAP: Dict[str, type] = {}
@classmethod @classmethod
def register(cls, name: str): def register(cls, name: str):
"""Decorator to register a new strategy class. """Decorator to register a new strategy class.
Args: Args:
name: Registration name for the strategy name: Registration name for the strategy
Returns: Returns:
Decorator function that registers the strategy class Decorator function that registers the strategy class
""" """
def decorator(strategy_cls: type) -> type: def decorator(strategy_cls: type) -> type:
if not issubclass(strategy_cls, BaseStrategy): if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") raise TypeError(
f"{strategy_cls.__name__} must inherit from BaseStrategy"
)
cls.STRATEGY_MAP[name] = strategy_cls cls.STRATEGY_MAP[name] = strategy_cls
return strategy_cls return strategy_cls
return decorator return decorator
@classmethod @classmethod
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy: def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
"""Create a strategy instance based on training type. """Create a strategy instance based on training type.
Args: Args:
model: Model instance for the strategy model: Model instance for the strategy
train_type: Type of training ("seq", "sft", "dpo", "grpo") train_type: Type of training ("seq", "sft", "dpo", "grpo")
device: Device to run the strategy on device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor **kwargs: Additional arguments passed to strategy constructor
Returns: Returns:
Strategy instance Strategy instance
Raises: Raises:
ValueError: If train_type is not supported ValueError: If train_type is not supported
NotImplementedError: If train_type is in supported list but not implemented NotImplementedError: If train_type is in supported list but not implemented
@ -157,15 +165,15 @@ class StrategyFactory:
f"Unknown training strategy: '{train_type}'. " f"Unknown training strategy: '{train_type}'. "
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}" f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
) )
if train_type not in cls.STRATEGY_MAP: if train_type not in cls.STRATEGY_MAP:
raise NotImplementedError( raise NotImplementedError(
f"Strategy '{train_type}' is supported but not yet implemented." f"Strategy '{train_type}' is supported but not yet implemented."
) )
strategy_cls = cls.STRATEGY_MAP[train_type] strategy_cls = cls.STRATEGY_MAP[train_type]
return strategy_cls(model, device, **kwargs) return strategy_cls(model, device, **kwargs)
@classmethod @classmethod
def available_strategies(cls) -> list: def available_strategies(cls) -> list:
"""Return list of registered strategy names.""" """Return list of registered strategy names."""
@ -179,77 +187,81 @@ class StrategyFactory:
@StrategyFactory.register("seq") @StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy): class SEQStrategy(BaseStrategy):
"""Standard next-token prediction training strategy. """Standard next-token prediction training strategy.
Computes cross-entropy loss for next token prediction. Computes cross-entropy loss for next token prediction.
""" """
def __init__(self, model, device, label_smoothing: float = 0.0): def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device) super().__init__(model, device)
self.label_smoothing = label_smoothing self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"] input_ids, target_ids = batch["input_ids"], batch["target_ids"]
logits = self.model(input_ids=input_ids)["logits"] logits = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),
target=target_ids.flatten(), target=target_ids.flatten(),
label_smoothing=self.label_smoothing label_smoothing=self.label_smoothing,
) )
return loss return loss
@StrategyFactory.register("sft") @StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy): class SFTStrategy(BaseStrategy):
"""Supervised Fine-tuning strategy with loss masking. """Supervised Fine-tuning strategy with loss masking.
Applies cross-entropy loss only to tokens where loss_mask is True. Applies cross-entropy loss only to tokens where loss_mask is True.
""" """
def __init__(self, model, device, label_smoothing: float = 0.0): def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device) super().__init__(model, device)
self.label_smoothing = label_smoothing self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"] input_ids, target_ids, loss_mask = (
batch["input_ids"],
batch["target_ids"],
batch["loss_mask"],
)
ignore_index = -100 ignore_index = -100
logits = self.model(input_ids=input_ids)["logits"] logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),
target=target_ids.flatten(), target=target_ids.flatten(),
ignore_index=ignore_index, ignore_index=ignore_index,
label_smoothing=self.label_smoothing label_smoothing=self.label_smoothing,
) )
return loss return loss
@StrategyFactory.register("dpo") @StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy): class DPOStrategy(BaseStrategy):
"""Direct Preference Optimization strategy. """Direct Preference Optimization strategy.
Implements the DPO loss from the paper "Direct Preference Optimization". Implements the DPO loss from the paper "Direct Preference Optimization".
Uses a reference model to compute KL divergence penalty. Uses a reference model to compute KL divergence penalty.
""" """
def __init__( def __init__(
self, self,
model: nn.Module, model: nn.Module,
device: str, device: str,
beta: float = 0.1, beta: float = 0.1,
reduction: str = "mean", reduction: str = "mean",
): ):
super().__init__(model, device) super().__init__(model, device)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(model)
self.beta = beta self.beta = beta
self.reduction = reduction self.reduction = reduction
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"] chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
@ -257,17 +269,19 @@ class DPOStrategy(BaseStrategy):
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0) contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0) contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction) log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
with torch.no_grad(): with torch.no_grad():
log_ref = get_logprobs(self.ref_model, contact_ids, contact_mask, self.reduction) log_ref = get_logprobs(
self.ref_model, contact_ids, contact_mask, self.reduction
log_pi_chosen = log_pi[:chosen_ids.shape[0]] )
log_pi_rejected = log_pi[chosen_ids.shape[0]:]
log_ref_chosen = log_ref[:chosen_ids.shape[0]] log_pi_chosen = log_pi[: chosen_ids.shape[0]]
log_ref_rejected = log_ref[chosen_ids.shape[0]:] log_pi_rejected = log_pi[chosen_ids.shape[0] :]
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
pi_log_ratio = log_pi_chosen - log_pi_rejected pi_log_ratio = log_pi_chosen - log_pi_rejected
ref_log_ratio = log_ref_chosen - log_ref_rejected ref_log_ratio = log_ref_chosen - log_ref_rejected
@ -280,14 +294,14 @@ class DPOStrategy(BaseStrategy):
@StrategyFactory.register("grpo") @StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy): class GRPOStrategy(BaseStrategy):
"""Group Relative Policy Optimization strategy. """Group Relative Policy Optimization strategy.
Implements GRPO with clipping and KL penalty. Implements GRPO with clipping and KL penalty.
""" """
def __init__( def __init__(
self, self,
model: nn.Module, model: nn.Module,
device: str, device: str,
clip_eps: float = 0.2, clip_eps: float = 0.2,
kl_coef: float = 0.01, kl_coef: float = 0.01,
group_size: int = 4, group_size: int = 4,
@ -299,43 +313,47 @@ class GRPOStrategy(BaseStrategy):
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size
self.reduction = reduction self.reduction = reduction
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
prompts = batch["prompts"] prompts = batch["prompts"]
responses = batch["responses"] responses = batch["responses"]
masks = batch["masks"] masks = batch["masks"]
rewards = batch["rewards"] rewards = batch["rewards"]
batch_size, group_size, response_len = responses.shape batch_size, group_size, response_len = responses.shape
responses_flat = responses.view(-1, response_len) responses_flat = responses.view(-1, response_len)
masks_flat = masks.view(-1, response_len) masks_flat = masks.view(-1, response_len)
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1) prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
# Shape: (batch_size * group_size, seq_len + response_len) # Shape: (batch_size * group_size, seq_len + response_len)
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1) full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1) full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction) log_probs_policy = get_logprobs(
self.model, full_sequences, full_masks, self.reduction
)
log_probs_policy = log_probs_policy.view(batch_size, group_size) log_probs_policy = log_probs_policy.view(batch_size, group_size)
with torch.no_grad(): with torch.no_grad():
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction) log_probs_ref = get_logprobs(
self.ref_model, full_sequences, full_masks, self.reduction
)
log_probs_ref = log_probs_ref.view(batch_size, group_size) log_probs_ref = log_probs_ref.view(batch_size, group_size)
# Compute advantages from rewards with normalization # Compute advantages from rewards with normalization
eps = torch.finfo(log_probs_policy.dtype).eps eps = torch.finfo(log_probs_policy.dtype).eps
mean = rewards.mean(dim=-1, keepdim=True) mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True) std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + eps) advantages = (rewards - mean) / (std + eps)
# PPO-style clipped surrogate objective # PPO-style clipped surrogate objective
ratio = torch.exp(0) # Off-policy: policy_model = old_model ratio = torch.exp(0) # Off-policy: policy_model = old_model
surr1 = ratio * advantages surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean() policy_loss = -torch.min(surr1, surr2).mean()
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean() kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
total_loss = policy_loss + kl_penalty total_loss = policy_loss + kl_penalty
return total_loss return total_loss

View File

@ -18,52 +18,53 @@ from khaosz.trainer.metric_util import (
ctx_get_grad_norm, ctx_get_grad_norm,
ctx_get_grad_mean, ctx_get_grad_mean,
ctx_get_grad_std, ctx_get_grad_std,
ctx_get_grad_nan_num ctx_get_grad_nan_num,
) )
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
from khaosz.trainer.train_context import TrainContext from khaosz.trainer.train_context import TrainContext
class TrainCallback(Protocol): class TrainCallback(Protocol):
""" """
Callback interface for trainer. Callback interface for trainer.
""" """
def on_train_begin(self, context: TrainContext): def on_train_begin(self, context: TrainContext):
""" Called at the beginning of training. """ """Called at the beginning of training."""
def on_train_end(self, context: TrainContext): def on_train_end(self, context: TrainContext):
""" Called at the end of training. """ """Called at the end of training."""
def on_epoch_begin(self, context: TrainContext): def on_epoch_begin(self, context: TrainContext):
""" Called at the beginning of each epoch. """ """Called at the beginning of each epoch."""
def on_epoch_end(self, context: TrainContext): def on_epoch_end(self, context: TrainContext):
""" Called at the end of each epoch. """ """Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext): def on_step_begin(self, context: TrainContext):
""" Called at the beginning of each step. """ """Called at the beginning of each step."""
def on_step_end(self, context: TrainContext): def on_step_end(self, context: TrainContext):
""" Called at the end of each step.""" """Called at the end of each step."""
def on_batch_begin(self, context: TrainContext): def on_batch_begin(self, context: TrainContext):
""" Called at the beginning of each batch. """ """Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
""" Called at the end of each batch. """ """Called at the end of each batch."""
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
""" Called when an error occurs during training. """ """Called when an error occurs during training."""
class GradientClippingCallback(TrainCallback): class GradientClippingCallback(TrainCallback):
""" """
Gradient clipping callback for trainer. Gradient clipping callback for trainer.
""" """
def __init__(self, max_grad_norm: float): def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
def on_step_begin(self, context: TrainContext): def on_step_begin(self, context: TrainContext):
_ = context _ = context
clip_grad_norm_(context.model.parameters(), self.max_grad_norm) clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -73,86 +74,95 @@ class SchedulerCallback(TrainCallback):
""" """
Scheduler callback for trainer. Scheduler callback for trainer.
""" """
def __init__(self): def __init__(self):
pass pass
def on_train_begin(self, context: TrainContext): def on_train_begin(self, context: TrainContext):
for group in context.optimizer.param_groups: for group in context.optimizer.param_groups:
if "initial_lr" not in group: if "initial_lr" not in group:
group["initial_lr"] = group["lr"] group["initial_lr"] = group["lr"]
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.scheduler: if context.scheduler:
context.scheduler.step() context.scheduler.step()
class CheckpointCallback(TrainCallback): class CheckpointCallback(TrainCallback):
""" """
Checkpoint callback for trainer. Checkpoint callback for trainer.
""" """
def __init__( def __init__(
self, self,
save_dir: str, save_dir: str,
interval: int, interval: int,
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0) @only_on_rank(0)
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") save_path = os.path.join(
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict() self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
epoch=context.epoch,
iteration=context.iteration
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval: if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context) self._save_checkpoint(context)
def on_train_end(self, context: TrainContext): def on_train_end(self, context: TrainContext):
if context.iteration != self.last_ckpt_iter: if context.iteration != self.last_ckpt_iter:
self._save_checkpoint(context) self._save_checkpoint(context)
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
self._save_checkpoint(context) self._save_checkpoint(context)
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):
""" """
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
def __init__(self, num_epoch: int): def __init__(self, num_epoch: int):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
@only_on_rank(0) @only_on_rank(0)
def on_epoch_begin(self, context: TrainContext): def on_epoch_begin(self, context: TrainContext):
self.progress_bar = tqdm( self.progress_bar = tqdm(
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch+1}/{self.num_epoch}", desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True dynamic_ncols=True,
) )
@only_on_rank(0) @only_on_rank(0)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
self.progress_bar.set_postfix({ self.progress_bar.set_postfix(
"loss": f"{context.loss:.4f}", {
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}" "loss": f"{context.loss:.4f}",
}) "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
}
)
self.progress_bar.update(1) self.progress_bar.update(1)
@only_on_rank(0) @only_on_rank(0)
def on_epoch_end(self, context: TrainContext): def on_epoch_end(self, context: TrainContext):
_ = context _ = context
@ -162,66 +172,65 @@ class ProgressBarCallback(TrainCallback):
class MetricLoggerCallback(TrainCallback): class MetricLoggerCallback(TrainCallback):
def __init__( def __init__(
self, self,
log_dir:str, log_dir: str,
save_interval:int, save_interval: int,
log_interval:int=10, log_interval: int = 10,
metrics:List[str]=None metrics: List[str] = None,
): ):
self.last_log_iter = 0 self.last_log_iter = 0
self.save_interval = save_interval self.save_interval = save_interval
self.log_interval = log_interval self.log_interval = log_interval
self.metrics = metrics or ['loss', 'lr'] self.metrics = metrics or ["loss", "lr"]
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs" self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
self.log_dir.mkdir(parents=True, exist_ok=True) self.log_dir.mkdir(parents=True, exist_ok=True)
self.log_cache = [] self.log_cache = []
self._metric_funcs = { self._metric_funcs = {
'loss': ctx_get_loss, "loss": ctx_get_loss,
'lr': ctx_get_lr, "lr": ctx_get_lr,
'grad_norm': ctx_get_grad_norm, "grad_norm": ctx_get_grad_norm,
'grad_std': ctx_get_grad_std, "grad_std": ctx_get_grad_std,
'grad_max': ctx_get_grad_max, "grad_max": ctx_get_grad_max,
'grad_min': ctx_get_grad_min, "grad_min": ctx_get_grad_min,
'grad_mean': ctx_get_grad_mean, "grad_mean": ctx_get_grad_mean,
'grad_nan_num': ctx_get_grad_nan_num "grad_nan_num": ctx_get_grad_nan_num,
} }
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { return {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.iteration, "iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics} **{m: self._metric_funcs[m](context) for m in self.metrics},
} }
@only_on_rank(0) @only_on_rank(0)
def _add_log(self, log_data): def _add_log(self, log_data):
self.log_cache.append(log_data) self.log_cache.append(log_data)
@only_on_rank(0) @only_on_rank(0)
def _save_log(self, epoch, iter): def _save_log(self, epoch, iter):
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl" log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
with open(log_file, 'w') as f: with open(log_file, "w") as f:
for log in self.log_cache: for log in self.log_cache:
f.write(json.dumps(log) + '\n') f.write(json.dumps(log) + "\n")
def on_batch_end(self, context): def on_batch_end(self, context):
if context.iteration % self.log_interval == 0: if context.iteration % self.log_interval == 0:
log_data = self._get_log_data(context) log_data = self._get_log_data(context)
self._add_log(log_data) self._add_log(log_data)
if context.iteration - self.last_log_iter >= self.save_interval: if context.iteration - self.last_log_iter >= self.save_interval:
self._save_log(context.epoch, context.iteration) self._save_log(context.epoch, context.iteration)
self.last_log_iter = context.iteration self.last_log_iter = context.iteration
def on_train_end(self, context): def on_train_end(self, context):
if context.iteration != self.last_log_iter: if context.iteration != self.last_log_iter:
self._save_log(context.epoch, context.iteration) self._save_log(context.epoch, context.iteration)
def on_error(self, context): def on_error(self, context):
self._save_log(context.epoch, context.iteration) self._save_log(context.epoch, context.iteration)

View File

@ -21,11 +21,11 @@ class TrainContext:
optimizer: Optimizer = field(default=None) optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None) scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
world_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
kwargs: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict)
@ -39,17 +39,17 @@ class TrainContextBuilder:
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
) )
device = get_current_device() device = get_current_device()
self._context.model = self._context.model.to(device=device) self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1: if self.config.nprocs > 1:
fn = self.config.parallel_wrapper fn = self.config.parallel_wrapper
self._context.model = fn(self._context.model) self._context.model = fn(self._context.model)
self._context.optimizer = self.config.optimizer_fn(self._context.model) self._context.optimizer = self.config.optimizer_fn(self._context.model)
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer) self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None: if checkpoint is None:
checkpoint = Checkpoint( checkpoint = Checkpoint(
@ -60,10 +60,10 @@ class TrainContextBuilder:
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch) self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.model.load_state_dict(checkpoint.state_dict) self._context.model.load_state_dict(checkpoint.state_dict)
self._context.checkpoint = checkpoint self._context.checkpoint = checkpoint
return self return self
def with_dataloader(self) -> Self: def with_dataloader(self) -> Self:
# fix: change batch level iteration to sample level offset # fix: change batch level iteration to sample level offset
config = self.config config = self.config
@ -72,28 +72,28 @@ class TrainContextBuilder:
data_source=config.dataset, data_source=config.dataset,
start_epoch=self._context.epoch, start_epoch=self._context.epoch,
start_iter=sampler_offset, start_iter=sampler_offset,
seed=config.random_seed seed=config.random_seed,
) )
dataloader = DataLoader( dataloader = DataLoader(
config.dataset, config.dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
sampler=resumeable_sampler, sampler=resumeable_sampler,
num_workers=config.num_workers, num_workers=config.num_workers,
pin_memory=config.pin_memory, pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor prefetch_factor=config.prefetch_factor,
) )
self._context.dataloader = dataloader self._context.dataloader = dataloader
return self return self
def with_strategy(self) -> Self: def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create( self._context.strategy = StrategyFactory.create(
model=self._context.model, model=self._context.model,
train_type=self.config.strategy, train_type=self.config.strategy,
device=get_current_device(), device=get_current_device(),
**self.config.extra_kwargs **self.config.extra_kwargs,
) )
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
return self._context return self._context

View File

@ -2,12 +2,12 @@ import logging
from typing import Optional, List from typing import Optional, List
from khaosz.config import TrainConfig from khaosz.config import TrainConfig
from khaosz.trainer.train_callback import ( from khaosz.trainer.train_callback import (
TrainCallback, TrainCallback,
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
MetricLoggerCallback, MetricLoggerCallback,
GradientClippingCallback, GradientClippingCallback,
SchedulerCallback SchedulerCallback,
) )
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
@ -18,37 +18,39 @@ logger = logging.getLogger(__name__)
class Trainer: class Trainer:
def __init__( def __init__(
self, self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
train_config: TrainConfig,
callbacks: Optional[List[TrainCallback]] = None
): ):
self.train_config = train_config self.train_config = train_config
default_callbacks = self._get_default_callbacks() default_callbacks = self._get_default_callbacks()
self.callbacks = default_callbacks + callbacks if callbacks else default_callbacks self.callbacks = (
default_callbacks + callbacks if callbacks else default_callbacks
)
def _get_default_callbacks(self) -> List[TrainCallback]: def _get_default_callbacks(self) -> List[TrainCallback]:
train_config = self.train_config train_config = self.train_config
return [ return [
ProgressBarCallback(train_config.n_epoch), ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
GradientClippingCallback(train_config.max_grad_norm), GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(), SchedulerCallback(),
] ]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self.train_config) return (
.with_checkpoint(checkpoint) TrainContextBuilder(self.train_config)
.with_dataloader() .with_checkpoint(checkpoint)
.with_strategy() .with_dataloader()
.build()) .with_strategy()
.build()
)
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks: for callback in self.callbacks:
method = getattr(callback, method_name, None) method = getattr(callback, method_name, None)
if method: if method:
method(context) method(context)
def train(self, checkpoint: Optional[Checkpoint] = None): def train(self, checkpoint: Optional[Checkpoint] = None):
config = self.train_config config = self.train_config
spawn_parallel_fn( spawn_parallel_fn(
@ -59,45 +61,45 @@ class Trainer:
master_port=config.master_port, master_port=config.master_port,
device_type=config.device_type, device_type=config.device_type,
device_ids=config.device_ids, device_ids=config.device_ids,
checkpoint=checkpoint checkpoint=checkpoint,
) )
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint) context = self._build_context(checkpoint)
self._call_callbacks('on_train_begin', context) self._call_callbacks("on_train_begin", context)
try: try:
context.model.train() context.model.train()
# 1.epoch # 1.epoch
for epoch in range(context.epoch, self.train_config.n_epoch): for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch context.epoch = epoch
self._call_callbacks('on_epoch_begin', context) self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader: for batch in context.dataloader:
# 3. batch # 3. batch
self._call_callbacks('on_batch_begin', context) self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch) loss = context.strategy(batch)
context.loss = loss.item() context.loss = loss.item()
context.iteration += 1 context.iteration += 1
# to make the loss normalized by accumulation steps # to make the loss normalized by accumulation steps
stand_loss = loss / self.train_config.accumulation_steps stand_loss = loss / self.train_config.accumulation_steps
stand_loss.backward() stand_loss.backward()
self._call_callbacks('on_batch_end', context) self._call_callbacks("on_batch_end", context)
if context.iteration % self.train_config.accumulation_steps == 0: if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step # 2. step
self._call_callbacks('on_step_begin', context) self._call_callbacks("on_step_begin", context)
context.optimizer.step() context.optimizer.step()
context.optimizer.zero_grad() context.optimizer.zero_grad()
self._call_callbacks('on_step_end', context) self._call_callbacks("on_step_end", context)
self._call_callbacks("on_epoch_end", context)
self._call_callbacks('on_epoch_end', context)
except Exception as e: except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True) logger.error(f"Training failed: {str(e)}", exc_info=True)
self._call_callbacks('on_error', context) self._call_callbacks("on_error", context)
raise raise
finally: finally:
self._call_callbacks('on_train_end', context) self._call_callbacks("on_train_end", context)

View File

@ -26,7 +26,7 @@ classifiers = [
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" } urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
[project.optional-dependencies] [project.optional-dependencies]
dev = ["pytest==9.0.2"] dev = ["pytest==9.0.2", "ruff"]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["."] where = ["."]
@ -35,4 +35,13 @@ where = ["."]
extra-index-url = "https://download.pytorch.org/whl/cu126" extra-index-url = "https://download.pytorch.org/whl/cu126"
[tool.setuptools.dynamic] [tool.setuptools.dynamic]
version = { attr = "khaosz.__version__" } version = { attr = "khaosz.__version__" }
[tool.ruff]
target-version = "py312"
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"

View File

@ -17,14 +17,14 @@ class RandomDataset(Dataset):
self.length = length or int(np.random.randint(100, 200)) self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
def __len__(self): def __len__(self):
return self.length return self.length
def __getitem__(self, idx): def __getitem__(self, idx):
return { return {
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)), "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,)) "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
} }
@ -33,10 +33,10 @@ class MultiTurnDataset(Dataset):
self.length = length or int(np.random.randint(100, 200)) self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
def __len__(self): def __len__(self):
return self.length return self.length
def __getitem__(self, idx): def __getitem__(self, idx):
input_ids = torch.randint(0, self.vocab_size, (self.max_length,)) input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
target_ids = torch.randint(0, self.vocab_size, (self.max_length,)) target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
@ -54,18 +54,18 @@ class EarlyStoppingDataset(Dataset):
self.length = length self.length = length
self.stop_after = stop_after self.stop_after = stop_after
self.count = 0 self.count = 0
def __len__(self): def __len__(self):
return self.length return self.length
def __getitem__(self, idx): def __getitem__(self, idx):
self.count += 1 self.count += 1
if self.count == self.stop_after: if self.count == self.stop_after:
raise RuntimeError("Simulated early stopping") raise RuntimeError("Simulated early stopping")
return { return {
"input_ids": torch.randint(0, 1000, (64,)), "input_ids": torch.randint(0, 1000, (64,)),
"target_ids": torch.randint(0, 1000, (64,)) "target_ids": torch.randint(0, 1000, (64,)),
} }
@ -74,10 +74,10 @@ def base_test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__ func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json") config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32] n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4] n_head_choices = [2, 4]
dim = int(np.random.choice(n_dim_choices)) dim = int(np.random.choice(n_dim_choices))
n_heads = int(np.random.choice(n_head_choices)) n_heads = int(np.random.choice(n_head_choices))
n_kv_heads = n_heads // 2 n_kv_heads = n_heads // 2
@ -91,16 +91,16 @@ def base_test_env(request: pytest.FixtureRequest):
"dim_ffn": dim_ffn, "dim_ffn": dim_ffn,
"max_len": 1024, "max_len": 1024,
"n_layers": 4, "n_layers": 4,
"norm_eps": 1e-5 "norm_eps": 1e-5,
} }
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config, f) json.dump(config, f)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_config = ModelConfig().load(config_path) transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config).to(device=device) model = Transformer(transformer_config).to(device=device)
tokenizer = BpeTokenizer() tokenizer = BpeTokenizer()
yield { yield {
"device": device, "device": device,
"test_dir": str(test_dir), "test_dir": str(test_dir),
@ -109,20 +109,23 @@ def base_test_env(request: pytest.FixtureRequest):
"model": model, "model": model,
"tokenizer": tokenizer, "tokenizer": tokenizer,
} }
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
@pytest.fixture @pytest.fixture
def random_dataset(): def random_dataset():
dataset = RandomDataset() dataset = RandomDataset()
yield dataset yield dataset
@pytest.fixture @pytest.fixture
def multi_turn_dataset(): def multi_turn_dataset():
dataset = MultiTurnDataset() dataset = MultiTurnDataset()
yield dataset yield dataset
@pytest.fixture @pytest.fixture
def early_stopping_dataset(): def early_stopping_dataset():
dataset = EarlyStoppingDataset() dataset = EarlyStoppingDataset()
yield dataset yield dataset

View File

@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
from khaosz.parallel.setup import get_rank, spawn_parallel_fn from khaosz.parallel.setup import get_rank, spawn_parallel_fn
def test_single_process(): def test_single_process():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)
@ -14,34 +15,31 @@ def test_single_process():
for epoch in range(3): for epoch in range(3):
for iteration in range(10): for iteration in range(10):
x = torch.randn(32, 10) x = torch.randn(32, 10)
y = torch.randn(32, 5) y = torch.randn(32, 5)
loss = model(x).mean() loss = model(x).mean()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
scheduler.step() scheduler.step()
checkpoint = Checkpoint( checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30)
state_dict=model.state_dict(),
epoch=3,
iteration=30
)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir) checkpoint.save(tmpdir)
loaded_checkpoint = Checkpoint.load(tmpdir) loaded_checkpoint = Checkpoint.load(tmpdir)
assert loaded_checkpoint.epoch == 3 assert loaded_checkpoint.epoch == 3
assert loaded_checkpoint.iteration == 30 assert loaded_checkpoint.iteration == 30
def simple_training(): def simple_training():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=10) scheduler = CosineAnnealingLR(optimizer, T_max=10)
for epoch in range(2): for epoch in range(2):
for iteration in range(5): for iteration in range(5):
x = torch.randn(16, 10) x = torch.randn(16, 10)
@ -57,28 +55,23 @@ def simple_training():
epoch=2, epoch=2,
iteration=10, iteration=10,
) )
rank = get_rank() rank = get_rank()
if rank == 0: if rank == 0:
shared_dir = tempfile.mkdtemp() shared_dir = tempfile.mkdtemp()
checkpoint.save(shared_dir) checkpoint.save(shared_dir)
else: else:
shared_dir = None shared_dir = None
if dist.is_initialized(): if dist.is_initialized():
dir_list = [shared_dir] dir_list = [shared_dir]
dist.broadcast_object_list(dir_list, src=0) dist.broadcast_object_list(dir_list, src=0)
shared_dir = dir_list[0] shared_dir = dir_list[0]
loaded = Checkpoint.load(shared_dir) loaded = Checkpoint.load(shared_dir)
assert loaded.epoch == 2 assert loaded.epoch == 2
def test_multi_process(): def test_multi_process():
spawn_parallel_fn( spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
simple_training,
world_size=2,
backend="gloo"
)

View File

@ -5,30 +5,32 @@ from khaosz.data.serialization import save_h5
from khaosz.data.dataset import * from khaosz.data.dataset import *
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):
"""Test dataset loader with multiple random paths""" """Test dataset loader with multiple random paths"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
# Create multiple mmap dataset directories with random data # Create multiple mmap dataset directories with random data
num_files = np.random.randint(2, 5) num_files = np.random.randint(2, 5)
for i in range(num_files): for i in range(num_files):
seq_length = np.random.randint(200, 400) seq_length = np.random.randint(200, 400)
dummy_data = { dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)], "sequence": [
torch.randint(0, 1000, (seq_length,), dtype=torch.int64)
for _ in range(10)
],
} }
save_h5(test_dir, f"data_{i}", dummy_data) save_h5(test_dir, f"data_{i}", dummy_data)
# Test loading with multiple paths # Test loading with multiple paths
loaded_dataset = DatasetLoader.load( loaded_dataset = DatasetLoader.load(
train_type="seq", train_type="seq",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,
) )
assert loaded_dataset is not None assert loaded_dataset is not None
assert len(loaded_dataset) > 0 assert len(loaded_dataset) > 0
# Test that we can get items without errors # Test that we can get items without errors
for i in range(len(loaded_dataset)): for i in range(len(loaded_dataset)):
item = loaded_dataset[i] item = loaded_dataset[i]
@ -41,30 +43,30 @@ def test_dataset_loader_random_paths(base_test_env):
def test_dpo_strategy_with_random_data(base_test_env): def test_dpo_strategy_with_random_data(base_test_env):
"""Test DPO strategy with randomized preference data""" """Test DPO strategy with randomized preference data"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
# Create DPO-style data with memory mapping format # Create DPO-style data with memory mapping format
seq_length = np.random.randint(100, 200) seq_length = np.random.randint(100, 200)
dummy_data = { dummy_data = {
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)], "chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)] "rejected_mask": [torch.ones(seq_length, dtype=torch.bool)],
} }
save_h5(test_dir, "dpo_data", dummy_data) save_h5(test_dir, "dpo_data", dummy_data)
# Load DPO dataset # Load DPO dataset
dpo_dataset = DatasetLoader.load( dpo_dataset = DatasetLoader.load(
train_type="dpo", train_type="dpo",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,
) )
assert dpo_dataset is not None assert dpo_dataset is not None
assert hasattr(dpo_dataset, 'fetcher') assert hasattr(dpo_dataset, "fetcher")
assert len(dpo_dataset) > 0 assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors # Test that we can get DPO items without errors
for i in range(min(3, len(dpo_dataset))): for i in range(min(3, len(dpo_dataset))):
item = dpo_dataset[i] item = dpo_dataset[i]
@ -79,28 +81,28 @@ def test_dpo_strategy_with_random_data(base_test_env):
def test_sft_dataset_with_random_data(base_test_env): def test_sft_dataset_with_random_data(base_test_env):
"""Test SFT dataset with random data""" """Test SFT dataset with random data"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
# Create SFT-style data with memory mapping format # Create SFT-style data with memory mapping format
seq_length = np.random.randint(100, 200) seq_length = np.random.randint(100, 200)
dummy_data = { dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)] "loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
} }
save_h5(test_dir, "sft_data", dummy_data) save_h5(test_dir, "sft_data", dummy_data)
# Load SFT dataset # Load SFT dataset
sft_dataset = DatasetLoader.load( sft_dataset = DatasetLoader.load(
train_type="sft", train_type="sft",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,
) )
assert sft_dataset is not None assert sft_dataset is not None
assert hasattr(sft_dataset, 'fetcher') assert hasattr(sft_dataset, "fetcher")
assert len(sft_dataset) > 0 assert len(sft_dataset) > 0
# Test that we can get SFT items without errors # Test that we can get SFT items without errors
for i in range(min(3, len(sft_dataset))): for i in range(min(3, len(sft_dataset))):
item = sft_dataset[i] item = sft_dataset[i]
@ -114,33 +116,30 @@ def test_sft_dataset_with_random_data(base_test_env):
def test_dataset_with_custom_stride(base_test_env): def test_dataset_with_custom_stride(base_test_env):
"""Test dataset with custom stride parameter""" """Test dataset with custom stride parameter"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
# Create test data # Create test data
seq_length = 200 seq_length = 200
dummy_data = { dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
} }
save_h5(test_dir,"stride_test_data", dummy_data) save_h5(test_dir, "stride_test_data", dummy_data)
# Test with custom stride # Test with custom stride
custom_stride = 32 custom_stride = 32
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type="seq", train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
load_path=test_dir,
window_size=64,
stride=custom_stride
) )
assert dataset is not None assert dataset is not None
assert len(dataset) > 0 assert len(dataset) > 0
# With stride 32 and window 64 on 200 length data, we should get more samples # With stride 32 and window 64 on 200 length data, we should get more samples
# than with default stride (which equals window size) # than with default stride (which equals window size)
default_stride_dataset = DatasetLoader.load( default_stride_dataset = DatasetLoader.load(
train_type="seq", train_type="seq",
load_path=test_dir, load_path=test_dir,
window_size=64, window_size=64,
) )
assert len(dataset) > len(default_stride_dataset) assert len(dataset) > len(default_stride_dataset)

View File

@ -1,30 +1,32 @@
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data import * from khaosz.data import *
def test_random_sampler_consistency(random_dataset): def test_random_sampler_consistency(random_dataset):
"""Test RandomSampler produces consistent results with same seed""" """Test RandomSampler produces consistent results with same seed"""
dataset = random_dataset dataset = random_dataset
# Create two samplers with same seed # Create two samplers with same seed
sampler1 = ResumableDistributedSampler(dataset, seed=42) sampler1 = ResumableDistributedSampler(dataset, seed=42)
sampler2 = ResumableDistributedSampler(dataset, seed=42) sampler2 = ResumableDistributedSampler(dataset, seed=42)
indices1 = list(iter(sampler1)) indices1 = list(iter(sampler1))
indices2 = list(iter(sampler2)) indices2 = list(iter(sampler2))
assert indices1 == indices2 assert indices1 == indices2
def test_random_sampler_different_seeds(random_dataset): def test_random_sampler_different_seeds(random_dataset):
"""Test RandomSampler produces different results with different seeds""" """Test RandomSampler produces different results with different seeds"""
dataset = random_dataset dataset = random_dataset
# Create two samplers with different seeds # Create two samplers with different seeds
sampler1 = ResumableDistributedSampler(dataset, seed=42) sampler1 = ResumableDistributedSampler(dataset, seed=42)
sampler2 = ResumableDistributedSampler(dataset, seed=123) sampler2 = ResumableDistributedSampler(dataset, seed=123)
indices1 = list(iter(sampler1)) indices1 = list(iter(sampler1))
indices2 = list(iter(sampler2)) indices2 = list(iter(sampler2))
# Very high probability they should be different # Very high probability they should be different
assert indices1 != indices2 assert indices1 != indices2
@ -33,20 +35,20 @@ def test_sampler_across_epochs(random_dataset):
"""Test sampler behavior across multiple epochs""" """Test sampler behavior across multiple epochs"""
dataset = random_dataset dataset = random_dataset
n = len(dataset) n = len(dataset)
sampler = ResumableDistributedSampler(dataset, seed=42) sampler = ResumableDistributedSampler(dataset, seed=42)
# Get indices for first epoch # Get indices for first epoch
epoch1_indices = list(iter(sampler)) epoch1_indices = list(iter(sampler))
assert len(epoch1_indices) == n assert len(epoch1_indices) == n
# Get indices for second epoch # Get indices for second epoch
epoch2_indices = list(iter(sampler)) epoch2_indices = list(iter(sampler))
assert len(epoch2_indices) == n assert len(epoch2_indices) == n
# Check that epochs have different order (should be random) # Check that epochs have different order (should be random)
assert epoch1_indices != epoch2_indices assert epoch1_indices != epoch2_indices
# Check that all indices are present in each epoch # Check that all indices are present in each epoch
assert set(epoch1_indices) == set(range(n)) assert set(epoch1_indices) == set(range(n))
assert set(epoch2_indices) == set(range(n)) assert set(epoch2_indices) == set(range(n))

View File

@ -12,6 +12,7 @@ from khaosz.data import *
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
from tokenizers import pre_tokenizers from tokenizers import pre_tokenizers
@pytest.fixture @pytest.fixture
def test_env(request: pytest.FixtureRequest): def test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__ func_name = request.function.__name__
@ -19,7 +20,7 @@ def test_env(request: pytest.FixtureRequest):
config_path = os.path.join(test_dir, "config.json") config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json") tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors") model_path = os.path.join(test_dir, "model.safetensors")
config = { config = {
"vocab_size": 1000, "vocab_size": 1000,
"dim": 128, "dim": 128,
@ -28,20 +29,20 @@ def test_env(request: pytest.FixtureRequest):
"dim_ffn": 256, "dim_ffn": 256,
"max_len": 64, "max_len": 64,
"n_layers": 2, "n_layers": 2,
"norm_eps": 1e-5 "norm_eps": 1e-5,
} }
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config, f) json.dump(config, f)
tokenizer = BpeTokenizer() tokenizer = BpeTokenizer()
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet()) sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
tokenizer.save(tokenizer_path) tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path) transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config) model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path) st.save_file(model.state_dict(), model_path)
yield { yield {
"test_dir": test_dir, "test_dir": test_dir,
"model": model, "model": model,
@ -51,47 +52,55 @@ def test_env(request: pytest.FixtureRequest):
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
def test_model_parameter(test_env): def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save") save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"]) model_param = ModelParameter(
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
ModelParameter.save(model_param, save_dir) ModelParameter.save(model_param, save_dir)
assert os.path.exists(os.path.join(save_dir, "model.safetensors")) assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
assert os.path.exists(os.path.join(save_dir, "tokenizer.json")) assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
assert os.path.exists(os.path.join(save_dir, "config.json")) assert os.path.exists(os.path.join(save_dir, "config.json"))
# transformer # transformer
def test_transformer(test_env): def test_transformer(test_env):
model = test_env["model"] model = test_env["model"]
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, input_ids = torch.randint(
(4, test_env["transformer_config"].max_len)) 0,
test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].max_len),
)
output_logits = model(input_ids)["logits"] output_logits = model(input_ids)["logits"]
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size) target_shape = (
4,
test_env["transformer_config"].max_len,
test_env["transformer_config"].vocab_size,
)
assert output_logits.shape == target_shape assert output_logits.shape == target_shape
# generator # generator
def test_embedding_encoder_core(test_env): def test_embedding_encoder_core(test_env):
parameter = ModelParameter( parameter = ModelParameter(
test_env["model"], test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
test_env["tokenizer"],
test_env["transformer_config"]
) )
encoder = EmbeddingEncoderCore(parameter) encoder = EmbeddingEncoderCore(parameter)
single_emb = encoder.encode("测试文本") single_emb = encoder.encode("测试文本")
assert isinstance(single_emb, torch.Tensor) assert isinstance(single_emb, torch.Tensor)
assert single_emb.shape[-1] == test_env["transformer_config"].dim assert single_emb.shape[-1] == test_env["transformer_config"].dim
batch_emb = encoder.encode(["测试1", "测试2"]) batch_emb = encoder.encode(["测试1", "测试2"])
assert isinstance(batch_emb, list) assert isinstance(batch_emb, list)
assert len(batch_emb) == 2 assert len(batch_emb) == 2
def test_generator_core(test_env): def test_generator_core(test_env):
parameter = ModelParameter( parameter = ModelParameter(
test_env["model"], test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
test_env["tokenizer"],
test_env["transformer_config"]
) )
generator = GeneratorCore(parameter) generator = GeneratorCore(parameter)
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)) input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
@ -102,8 +111,8 @@ def test_generator_core(test_env):
top_p=0.95, top_p=0.95,
attn_mask=None, attn_mask=None,
kv_caches=None, kv_caches=None,
start_pos=0 start_pos=0,
) )
assert next_token_id.shape == (4, 1) assert next_token_id.shape == (4, 1)
assert cache_increase == 10 assert cache_increase == 10

View File

@ -13,7 +13,7 @@ def transformer_test_env():
"""创建Transformer测试专用环境""" """创建Transformer测试专用环境"""
test_dir = tempfile.mkdtemp(prefix="transformer_test_") test_dir = tempfile.mkdtemp(prefix="transformer_test_")
config_path = os.path.join(test_dir, "config.json") config_path = os.path.join(test_dir, "config.json")
config = { config = {
"vocab_size": 1000, "vocab_size": 1000,
"dim": 128, "dim": 128,
@ -22,18 +22,14 @@ def transformer_test_env():
"dim_ffn": 256, "dim_ffn": 256,
"max_len": 64, "max_len": 64,
"n_layers": 2, "n_layers": 2,
"norm_eps": 1e-5 "norm_eps": 1e-5,
} }
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config, f) json.dump(config, f)
yield { yield {"test_dir": test_dir, "config_path": config_path, "config": config}
"test_dir": test_dir,
"config_path": config_path,
"config": config
}
if os.path.exists(test_dir): if os.path.exists(test_dir):
try: try:
for file in os.listdir(test_dir): for file in os.listdir(test_dir):
@ -46,74 +42,75 @@ def transformer_test_env():
def test_tie_weight_init(transformer_test_env): def test_tie_weight_init(transformer_test_env):
config_path = transformer_test_env["config_path"] config_path = transformer_test_env["config_path"]
config_data = transformer_test_env["config"].copy() config_data = transformer_test_env["config"].copy()
# case 1: tie weight # case 1: tie weight
config_data["tie_weight"] = True config_data["tie_weight"] = True
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig().load(config_path)
model = Transformer(config) model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
original_weight = model.embed_tokens.weight.clone() original_weight = model.embed_tokens.weight.clone()
model.embed_tokens.weight.data[0, 0] = 100.0 model.embed_tokens.weight.data[0, 0] = 100.0
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert not torch.equal(model.lm_head.weight, original_weight) assert not torch.equal(model.lm_head.weight, original_weight)
# case 2: not tie weight # case 2: not tie weight
config_data["tie_weight"] = False config_data["tie_weight"] = False
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig().load(config_path)
model = Transformer(config) model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
original_weight = model.embed_tokens.weight.clone() original_weight = model.embed_tokens.weight.clone()
model.embed_tokens.weight.data[0, 0] = 100.0 model.embed_tokens.weight.data[0, 0] = 100.0
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert not torch.equal(model.lm_head.weight, original_weight) assert not torch.equal(model.lm_head.weight, original_weight)
def test_model_save_load_with_tie_weight(transformer_test_env): def test_model_save_load_with_tie_weight(transformer_test_env):
test_dir = transformer_test_env["test_dir"] test_dir = transformer_test_env["test_dir"]
model_path = os.path.join(test_dir, "model.safetensors") model_path = os.path.join(test_dir, "model.safetensors")
config_data = transformer_test_env["config"].copy() config_data = transformer_test_env["config"].copy()
# case 1: tie weight # case 1: tie weight
config_data["tie_weight"] = True config_data["tie_weight"] = True
config_path = os.path.join(test_dir, "config.json") config_path = os.path.join(test_dir, "config.json")
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig().load(config_path)
original_model = Transformer(config) original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path) st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig().load(config_path) loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config) model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
assert "lm_head.weight" not in model.state_dict() assert "lm_head.weight" not in model.state_dict()
# case 2: not tie weight (form tie-weight state dict load) # case 2: not tie weight (form tie-weight state dict load)
config_data["tie_weight"] = False config_data["tie_weight"] = False
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
loaded_config = ModelConfig().load(config_path) loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config) model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
@ -121,4 +118,3 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
assert "lm_head.weight" in model.state_dict() assert "lm_head.weight" in model.state_dict()

View File

@ -1,16 +1,14 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from khaosz.parallel import ( from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
get_rank,
only_on_rank,
spawn_parallel_fn
)
@only_on_rank(0) @only_on_rank(0)
def _test_only_on_rank_helper(): def _test_only_on_rank_helper():
return True return True
def only_on_rank(): def only_on_rank():
result = _test_only_on_rank_helper() result = _test_only_on_rank_helper()
if get_rank() == 0: if get_rank() == 0:
@ -18,22 +16,17 @@ def only_on_rank():
else: else:
assert result is None assert result is None
def all_reduce(): def all_reduce():
x = torch.tensor([get_rank()], dtype=torch.int) x = torch.tensor([get_rank()], dtype=torch.int)
dist.all_reduce(x, op=dist.ReduceOp.SUM) dist.all_reduce(x, op=dist.ReduceOp.SUM)
expected_sum = sum(range(dist.get_world_size())) expected_sum = sum(range(dist.get_world_size()))
assert x.item() == expected_sum assert x.item() == expected_sum
def test_spawn_only_on_rank(): def test_spawn_only_on_rank():
spawn_parallel_fn( spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo")
only_on_rank,
world_size=2,
backend="gloo"
)
def test_spawn_all_reduce(): def test_spawn_all_reduce():
spawn_parallel_fn( spawn_parallel_fn(all_reduce, world_size=2, backend="gloo")
all_reduce,
world_size=2,
backend="gloo"
)

View File

@ -3,57 +3,48 @@ import torch
from khaosz.config import * from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
def test_callback_integration(base_test_env, random_dataset): def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated""" """Test that all callbacks are properly integrated"""
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"], model=base_test_env["model"],
strategy='seq', strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=2, batch_size=2,
checkpoint_interval=3, ckpt_interval=3,
accumulation_steps=1, accumulation_steps=1,
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=42, random_seed=42,
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
# Create custom callbacks to track calls # Create custom callbacks to track calls
callback_calls = [] callback_calls = []
class TrackingCallback(TrainCallback): class TrackingCallback(TrainCallback):
def on_train_begin(self, context): def on_train_begin(self, context):
callback_calls.append('on_train_begin') callback_calls.append("on_train_begin")
def on_batch_end(self, context): def on_batch_end(self, context):
callback_calls.append('on_batch_end') callback_calls.append("on_batch_end")
def on_epoch_end(self, context): def on_epoch_end(self, context):
callback_calls.append('on_epoch_end') callback_calls.append("on_epoch_end")
trainer = Trainer(train_config, callbacks=[TrackingCallback()])
trainer = Trainer(
train_config,
callbacks=[TrackingCallback()]
)
trainer.train() trainer.train()
# Verify callbacks were called # Verify callbacks were called
assert 'on_train_begin' in callback_calls assert "on_train_begin" in callback_calls
assert 'on_batch_end' in callback_calls assert "on_batch_end" in callback_calls
assert 'on_epoch_end' in callback_calls assert "on_epoch_end" in callback_calls

View File

@ -5,31 +5,32 @@ from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
def test_early_stopping_simulation(base_test_env, early_stopping_dataset): def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior""" """Simulate early stopping behavior"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
strategy="seq", strategy="seq",
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
model=base_test_env["model"], model=base_test_env["model"],
dataset=early_stopping_dataset, dataset=early_stopping_dataset,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=2, n_epoch=2,
batch_size=2, batch_size=2,
checkpoint_interval=1, ckpt_interval=1,
accumulation_steps=2, accumulation_steps=2,
random_seed=np.random.randint(1e4), random_seed=np.random.randint(1e4),
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
# Should handle early stopping gracefully # Should handle early stopping gracefully
checkpoint = None checkpoint = None
try: try:
@ -37,11 +38,11 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
except Exception: except Exception:
# Handle any exceptions # Handle any exceptions
pass pass
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2") load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
checkpoint = Checkpoint.load(load_dir) checkpoint = Checkpoint.load(load_dir)
trainer.train(checkpoint) trainer.train(checkpoint)
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10") load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
checkpoint = Checkpoint.load(load_dir) checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10 assert checkpoint.iteration == 10

View File

@ -9,39 +9,41 @@ from khaosz.data.dataset import *
def test_schedule_factory_random_configs(): def test_schedule_factory_random_configs():
"""Test scheduler factory with random configurations""" """Test scheduler factory with random configurations"""
# Create a simple model and optimizer for testing # Create a simple model and optimizer for testing
model = torch.nn.Linear(10, 2) model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# Test multiple random configurations # Test multiple random configurations
for _ in range(5): # Test 5 random configurations for _ in range(5): # Test 5 random configurations
schedule_configs = [ schedule_configs = [
CosineScheduleConfig( CosineScheduleConfig(
warmup_steps=np.random.randint(50, 200), warmup_steps=np.random.randint(50, 200),
total_steps=np.random.randint(1000, 5000), total_steps=np.random.randint(1000, 5000),
min_rate=np.random.uniform(0.01, 0.1) min_rate=np.random.uniform(0.01, 0.1),
), ),
SGDRScheduleConfig( SGDRScheduleConfig(
warmup_steps=np.random.randint(50, 200), warmup_steps=np.random.randint(50, 200),
cycle_length=np.random.randint(500, 2000), cycle_length=np.random.randint(500, 2000),
t_mult=np.random.randint(1, 3), t_mult=np.random.randint(1, 3),
min_rate=np.random.uniform(0.01, 0.1) min_rate=np.random.uniform(0.01, 0.1),
) ),
] ]
for config in schedule_configs: for config in schedule_configs:
# Validate configuration # Validate configuration
config.validate() config.validate()
# Create scheduler using factory # Create scheduler using factory
scheduler = SchedulerFactory.load(optimizer, config) scheduler = SchedulerFactory.load(optimizer, config)
# Verify scheduler type # Verify scheduler type
if isinstance(config, CosineScheduleConfig): if isinstance(config, CosineScheduleConfig):
assert isinstance(scheduler, CosineScheduler) assert isinstance(scheduler, CosineScheduler)
assert scheduler.warmup_steps == config.warmup_steps assert scheduler.warmup_steps == config.warmup_steps
assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps assert (
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
)
assert scheduler.min_rate == config.min_rate assert scheduler.min_rate == config.min_rate
elif isinstance(config, SGDRScheduleConfig): elif isinstance(config, SGDRScheduleConfig):
assert isinstance(scheduler, SGDRScheduler) assert isinstance(scheduler, SGDRScheduler)
@ -49,17 +51,17 @@ def test_schedule_factory_random_configs():
assert scheduler.cycle_length == config.cycle_length assert scheduler.cycle_length == config.cycle_length
assert scheduler.t_mult == config.t_mult assert scheduler.t_mult == config.t_mult
assert scheduler.min_rate == config.min_rate assert scheduler.min_rate == config.min_rate
# Test scheduler state dict functionality # Test scheduler state dict functionality
state_dict = scheduler.state_dict() state_dict = scheduler.state_dict()
assert 'warmup_steps' in state_dict assert "warmup_steps" in state_dict
assert 'min_rate' in state_dict assert "min_rate" in state_dict
# Test scheduler step functionality # Test scheduler step functionality
initial_lr = scheduler.get_last_lr() initial_lr = scheduler.get_last_lr()
scheduler.step() scheduler.step()
new_lr = scheduler.get_last_lr() new_lr = scheduler.get_last_lr()
# Learning rate should change after step, or if it's the first step, # Learning rate should change after step, or if it's the first step,
# the epoch counter should increment # the epoch counter should increment
assert initial_lr != new_lr or scheduler.last_epoch > -1 assert initial_lr != new_lr or scheduler.last_epoch > -1
@ -67,10 +69,10 @@ def test_schedule_factory_random_configs():
def test_schedule_factory_edge_cases(): def test_schedule_factory_edge_cases():
"""Test scheduler factory with edge cases and boundary conditions""" """Test scheduler factory with edge cases and boundary conditions"""
model = torch.nn.Linear(10, 2) model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# Test edge cases for CosineScheduleConfig # Test edge cases for CosineScheduleConfig
edge_cases = [ edge_cases = [
# Minimal warmup and steps # Minimal warmup and steps
@ -80,12 +82,12 @@ def test_schedule_factory_edge_cases():
# Zero min_rate (edge case) # Zero min_rate (edge case)
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0), CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
] ]
for config in edge_cases: for config in edge_cases:
config.validate() config.validate()
scheduler = SchedulerFactory.load(optimizer, config) scheduler = SchedulerFactory.load(optimizer, config)
assert scheduler is not None assert scheduler is not None
# Test multiple steps # Test multiple steps
for _ in range(10): for _ in range(10):
scheduler.step() scheduler.step()
@ -93,7 +95,7 @@ def test_schedule_factory_edge_cases():
def test_schedule_factory_invalid_configs(): def test_schedule_factory_invalid_configs():
"""Test scheduler factory with invalid configurations""" """Test scheduler factory with invalid configurations"""
# Test invalid configurations that should raise errors # Test invalid configurations that should raise errors
invalid_configs = [ invalid_configs = [
# Negative warmup steps # Negative warmup steps
@ -104,7 +106,7 @@ def test_schedule_factory_invalid_configs():
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1}, {"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1}, {"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
] ]
for kwargs in invalid_configs: for kwargs in invalid_configs:
with pytest.raises(ValueError): with pytest.raises(ValueError):
config = CosineScheduleConfig(**kwargs) config = CosineScheduleConfig(**kwargs)
@ -113,24 +115,24 @@ def test_schedule_factory_invalid_configs():
def test_schedule_factory_state_persistence(): def test_schedule_factory_state_persistence():
"""Test scheduler state persistence (save/load)""" """Test scheduler state persistence (save/load)"""
model = torch.nn.Linear(10, 2) model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1) config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
scheduler = SchedulerFactory.load(optimizer, config) scheduler = SchedulerFactory.load(optimizer, config)
# Take a few steps # Take a few steps
for _ in range(5): for _ in range(5):
scheduler.step() scheduler.step()
# Save state # Save state
state_dict = scheduler.state_dict() state_dict = scheduler.state_dict()
# Create new scheduler and load state # Create new scheduler and load state
new_scheduler = SchedulerFactory.load(optimizer, config) new_scheduler = SchedulerFactory.load(optimizer, config)
new_scheduler.load_state_dict(state_dict) new_scheduler.load_state_dict(state_dict)
# Verify states match # Verify states match
assert scheduler.last_epoch == new_scheduler.last_epoch assert scheduler.last_epoch == new_scheduler.last_epoch
assert scheduler.get_last_lr() == new_scheduler.get_last_lr() assert scheduler.get_last_lr() == new_scheduler.get_last_lr()

View File

@ -6,100 +6,94 @@ from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.dataset import * from khaosz.data.dataset import *
def test_different_batch_sizes(base_test_env, random_dataset): def test_different_batch_sizes(base_test_env, random_dataset):
"""Test training with different batch sizes""" """Test training with different batch sizes"""
batch_sizes = [1, 2, 4, 8] batch_sizes = [1, 2, 4, 8]
for batch_size in batch_sizes: for batch_size in batch_sizes:
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
strategy="seq", strategy="seq",
model=base_test_env["model"], model=base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=batch_size, batch_size=batch_size,
checkpoint_interval=5, ckpt_interval=5,
accumulation_steps=1, accumulation_steps=1,
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=np.random.randint(1000), random_seed=np.random.randint(1000),
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
assert train_config.batch_size == batch_size assert train_config.batch_size == batch_size
def test_gradient_accumulation(base_test_env, random_dataset): def test_gradient_accumulation(base_test_env, random_dataset):
"""Test training with different gradient accumulation steps""" """Test training with different gradient accumulation steps"""
accumulation_steps_list = [1, 2, 4] accumulation_steps_list = [1, 2, 4]
for accumulation_steps in accumulation_steps_list: for accumulation_steps in accumulation_steps_list:
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
strategy="seq", strategy="seq",
model=base_test_env["model"], model=base_test_env["model"],
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
dataset=random_dataset, dataset=random_dataset,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=2, batch_size=2,
checkpoint_interval=10, ckpt_interval=10,
accumulation_steps=accumulation_steps, accumulation_steps=accumulation_steps,
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=42, random_seed=42,
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train() trainer.train()
assert train_config.accumulation_steps == accumulation_steps assert train_config.accumulation_steps == accumulation_steps
def test_memory_efficient_training(base_test_env, random_dataset): def test_memory_efficient_training(base_test_env, random_dataset):
"""Test training with memory-efficient configurations""" """Test training with memory-efficient configurations"""
# Test with smaller batch sizes and gradient checkpointing # Test with smaller batch sizes and gradient checkpointing
small_batch_configs = [ small_batch_configs = [
{"batch_size": 1, "accumulation_steps": 8}, {"batch_size": 1, "accumulation_steps": 8},
{"batch_size": 2, "accumulation_steps": 4}, {"batch_size": 2, "accumulation_steps": 4},
{"batch_size": 4, "accumulation_steps": 2} {"batch_size": 4, "accumulation_steps": 2},
] ]
for config in small_batch_configs: for config in small_batch_configs:
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
strategy="seq", strategy="seq",
model=base_test_env["model"], model=base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=config["batch_size"], batch_size=config["batch_size"],
checkpoint_interval=5, ckpt_interval=5,
accumulation_steps=config["accumulation_steps"], accumulation_steps=config["accumulation_steps"],
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=42, random_seed=42,
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
assert train_config.accumulation_steps == config["accumulation_steps"] assert train_config.accumulation_steps == config["accumulation_steps"]

View File

@ -17,41 +17,47 @@ class GenerationBenchmark:
self, self,
config: ModelConfig, config: ModelConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.float16 dtype: torch.dtype = torch.float16,
): ):
self.config = config self.config = config
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype) self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
def _initialize_kv_cache(self, batch_size: int) -> list: def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存""" """初始化KV缓存"""
config = self.config config = self.config
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads) shape = (
batch_size,
config.max_len,
config.n_layers,
config.n_kv_heads,
config.dim // config.n_heads,
)
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache) return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint( prompt_ids = torch.randint(
low=0, low=0,
high=self.config.vocab_size, high=self.config.vocab_size,
size=(batch_size, prompt_length), size=(batch_size, prompt_length),
device=self.device, device=self.device,
dtype=torch.long dtype=torch.long,
) )
gen_ids = torch.randint( gen_ids = torch.randint(
low=0, low=0,
high=self.config.vocab_size, high=self.config.vocab_size,
size=(batch_size, total_length - prompt_length), size=(batch_size, total_length - prompt_length),
device=self.device, device=self.device,
dtype=torch.long dtype=torch.long,
) )
return prompt_ids, gen_ids return prompt_ids, gen_ids
@torch.inference_mode() @torch.inference_mode()
def run_prefill_benchmark( def run_prefill_benchmark(
self, self,
@ -59,32 +65,38 @@ class GenerationBenchmark:
prompt_length: int = 512, prompt_length: int = 512,
num_trials: int = 10, num_trials: int = 10,
) -> BenchmarkResult: ) -> BenchmarkResult:
for _ in range(3): for _ in range(3):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length) prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
torch.cuda.synchronize() torch.cuda.synchronize()
total_time = 0.0 total_time = 0.0
total_tokens = batch_size * prompt_length * num_trials total_tokens = batch_size * prompt_length * num_trials
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length) prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
start_event.record() start_event.record()
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
end_event.record() end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " print(
f"({prompt_length / trial_time:.1f} tokens/s)") f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult( return BenchmarkResult(
total_tokens=total_tokens, total_tokens=total_tokens,
total_time=total_time, total_time=total_time,
@ -95,9 +107,9 @@ class GenerationBenchmark:
"prompt_length": prompt_length, "prompt_length": prompt_length,
"dtype": self.dtype, "dtype": self.dtype,
"device": self.device, "device": self.device,
} },
) )
@torch.inference_mode() @torch.inference_mode()
def run_decoding_benchmark( def run_decoding_benchmark(
self, self,
@ -106,39 +118,43 @@ class GenerationBenchmark:
gen_length: int = 128, gen_length: int = 128,
num_trials: int = 5, num_trials: int = 5,
) -> BenchmarkResult: ) -> BenchmarkResult:
total_time = 0.0 total_time = 0.0
total_tokens = batch_size * gen_length * num_trials total_tokens = batch_size * gen_length * num_trials
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length) batch_size, prompt_length, prompt_length + gen_length
)
kv_cache = self._initialize_kv_cache(batch_size) kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0) _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
start_event.record() start_event.record()
current_pos = prompt_length current_pos = prompt_length
for i in range(gen_length): for i in range(gen_length):
input_token = gen_ids[:, i:i+1] input_token = gen_ids[:, i : i + 1]
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos) _ = self.model(
input_token, persistent_key_values=kv_cache, start_pos=current_pos
)
current_pos += 1 current_pos += 1
end_event.record() end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)")
print(
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult( return BenchmarkResult(
total_tokens=total_tokens, total_tokens=total_tokens,
total_time=total_time, total_time=total_time,
@ -150,24 +166,28 @@ class GenerationBenchmark:
"gen_length": gen_length, "gen_length": gen_length,
"dtype": self.dtype, "dtype": self.dtype,
"device": self.device, "device": self.device,
} },
) )
def print_benchmark_result(result: BenchmarkResult): def print_benchmark_result(result: BenchmarkResult):
"""打印基准测试结果""" """打印基准测试结果"""
benchmark_type = result.metadata["benchmark_type"] benchmark_type = result.metadata["benchmark_type"]
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}") print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}") print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s") print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s") print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
if benchmark_type == "prefill": if benchmark_type == "prefill":
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}") print(
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding": elif benchmark_type == "decoding":
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}") print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}") print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80) print("-" * 80)
@ -183,16 +203,19 @@ if __name__ == "__main__":
n_layers=24, n_layers=24,
norm_eps=1e-5, norm_eps=1e-5,
) )
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark") print("Running Transformer Generation Benchmark")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5) prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, prompt_length=512, num_trials=5
)
print_benchmark_result(prefill_result) print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5) gen_result = benchmark.run_decoding_benchmark(
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
)
print_benchmark_result(gen_result) print_benchmark_result(gen_result)

View File

@ -21,10 +21,10 @@ def processor(
with disable_random_init(): with disable_random_init():
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
generator = BatchGenerator(param) generator = BatchGenerator(param)
with open(input_json_file, "r", encoding='utf-8') as f: with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_data] queries = [item[question_key] for item in input_data]
@ -41,26 +41,62 @@ def processor(
responses = generator.generate(request) responses = generator.generate(request)
with open(output_json_file, "w", encoding='utf-8') as f: with open(output_json_file, "w", encoding="utf-8") as f:
for query, response in zip(queries, responses): for query, response in zip(queries, responses):
output_item = {question_key: query, response_key: response} output_item = {question_key: query, response_key: response}
f.write(json.dumps(output_item, ensure_ascii=False) + '\n') f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.") parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.") parser.add_argument(
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.") "--model_dir", type=str, required=True, help="Path to the model directory."
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.") )
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.") parser.add_argument(
parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.") "--input_json_file",
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.") type=str,
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.") required=True,
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.") help="Path to the input JSONL file.",
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.") )
parser.add_argument(
"--output_json_file",
type=str,
required=True,
help="Path to the output JSONL file.",
)
parser.add_argument(
"--question_key",
type=str,
default="question",
help="Key for the question in the input JSON.",
)
parser.add_argument(
"--response_key",
type=str,
default="response",
help="Key for the response in the output JSON.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.60,
help="Temperature for generating responses.",
)
parser.add_argument(
"--top_k", type=int, default=30, help="Top-k value for generating responses."
)
parser.add_argument(
"--top_p",
type=float,
default=0.95,
help="Top-p value for generating responses.",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for generating responses."
)
args = parser.parse_args() args = parser.parse_args()
with torch.inference_mode(): with torch.inference_mode():
processor(**vars(args)) processor(**vars(args))

View File

@ -11,89 +11,99 @@ from khaosz.inference.core import disable_random_init
def compute_perplexity( def compute_perplexity(
model: nn.Module, model: nn.Module,
input_ids: Tensor, input_ids: Tensor,
input_mask: Tensor, input_mask: Tensor,
) -> Tensor: ) -> Tensor:
""" """
Compute the perplexity of a batch of input sequences, Compute the perplexity of a batch of input sequences,
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))). where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
""" """
output = model(input_ids, input_mask) output = model(input_ids, input_mask)
logits = output["logits"] logits = output["logits"]
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size] shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1] shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1] shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
loss = F.cross_entropy( loss = F.cross_entropy(
shifted_logits.flatten(0, 1), shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
shifted_input_ids.flatten(0, 1),
reduction='none'
) )
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1] loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
loss = loss * shifted_mask loss = loss * shifted_mask
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1) sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
perplexity = torch.exp(sentence_loss) # [batch_size] perplexity = torch.exp(sentence_loss) # [batch_size]
return perplexity return perplexity
def process_file( def process_file(
model_dir: str, model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
input_file: str,
output_file: str,
batch_size: int,
text_key: str
): ):
with disable_random_init(): with disable_random_init():
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
model = param.model model = param.model
tokenizer = param.tokenizer tokenizer = param.tokenizer
with open(input_file, "r", encoding='utf-8') as f: with open(input_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]
texts = [item[text_key] for item in input_data] texts = [item[text_key] for item in input_data]
encoded_texts = [tokenizer.encode(text) for text in texts] encoded_texts = [tokenizer.encode(text) for text in texts]
output_data = [] output_data = []
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"): for i in tqdm(
batch_encoded = encoded_texts[i:i + batch_size] range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
batch_texts = texts[i:i + batch_size] ):
batch_encoded = encoded_texts[i : i + batch_size]
batch_texts = texts[i : i + batch_size]
max_len = max(len(seq) for seq in batch_encoded) max_len = max(len(seq) for seq in batch_encoded)
padded_ids = [] padded_ids = []
masks = [] masks = []
for seq in batch_encoded: for seq in batch_encoded:
pad_len = max_len - len(seq) pad_len = max_len - len(seq)
padded_seq = [tokenizer.pad_id] * pad_len + seq padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [False] * pad_len + [True] * len(seq) mask = [False] * pad_len + [True] * len(seq)
padded_ids.append(padded_seq) padded_ids.append(padded_seq)
masks.append(mask) masks.append(mask)
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long) input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool) input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
perplexity = compute_perplexity(model, input_ids, input_mask) perplexity = compute_perplexity(model, input_ids, input_mask)
for text, ppl in zip(batch_texts, perplexity): for text, ppl in zip(batch_texts, perplexity):
output_data.append({text_key: text, "ppl": float(ppl.item())}) output_data.append({text_key: text, "ppl": float(ppl.item())})
with open(output_file, "w", encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
for item in output_data: for item in output_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n') f.write(json.dumps(item, ensure_ascii=False) + "\n")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.") parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.") parser.add_argument(
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.") "--model_dir", type=str, required=True, help="Path to the model directory."
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.") )
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.") parser.add_argument(
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.") "--input_file", type=str, required=True, help="Path to the input file."
)
parser.add_argument(
"--output_file", type=str, required=True, help="Path to the output file."
)
parser.add_argument(
"--batch_size", type=int, default=4, help="Batch size for evaluation."
)
parser.add_argument(
"--text_key",
type=str,
default="text",
help="Key for the text field in the input data.",
)
args = parser.parse_args() args = parser.parse_args()
with torch.inference_mode(): with torch.inference_mode():

View File

@ -15,40 +15,130 @@ from khaosz.parallel import get_rank
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.") parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.") parser.add_argument(
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.") "--train_type",
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.") type=str,
required=True,
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.") choices=["seq", "sft", "dpo"],
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.") help="Train type.",
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.") )
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.") parser.add_argument(
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.") "--data_root_path",
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.") type=str,
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.") required=True,
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.") help="Path to the root directory of the dataset.",
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.") )
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") parser.add_argument(
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.") "--param_path",
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory") type=str,
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.") required=True,
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.") help="Path to the model parameters or resume checkpoint.",
)
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for training."
)
parser.add_argument(
"--accumulation_steps",
type=int,
default=1,
help="Number of iterations between each optimizer step.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=1000,
help="Number of iters between warnings.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=1.0,
help="Max gradient norm for clipping.",
)
parser.add_argument(
"--adamw_beta1",
type=float,
default=0.9,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.95,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_weight_decay",
type=float,
default=0.01,
help="Weight decay for AdamW optimizer.",
)
parser.add_argument(
"--random_seed", type=int, default=3407, help="Random seed for reproducibility."
)
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of workers for data loading."
)
parser.add_argument(
"--no_pin_memory",
action="store_false",
dest="pin_memory",
help="Disable pin memory",
)
parser.add_argument(
"--window_size",
type=int,
default=None,
help="the max length of the input sequence.",
)
parser.add_argument(
"--stride", type=int, default=None, help="the step size of the input sequence."
)
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter") parser.add_argument(
"--label_smoothing",
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.") type=float,
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") default=0.1,
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") help="cross_entropy function label smoothing parameter",
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") )
parser.add_argument(
"--ckpt_interval",
type=int,
default=5000,
help="Number of iters between checkpoints.",
)
parser.add_argument(
"--ckpt_dir",
type=str,
default="checkpoint",
help="Directory to save checkpoints.",
)
parser.add_argument(
"--start_epoch", type=int, default=0, help="Start epoch for training."
)
parser.add_argument(
"--start_batch", type=int, default=0, help="Start batch for training."
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.") parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
args = parser.parse_args() args = parser.parse_args()
return args return args
def ddp_wrap(model: nn.Module): def ddp_wrap(model: nn.Module):
local_rank = get_rank() local_rank = get_rank()
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16) model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
@ -56,16 +146,21 @@ def ddp_wrap(model: nn.Module):
model, model,
device_ids=[local_rank], device_ids=[local_rank],
output_device=local_rank, output_device=local_rank,
find_unused_parameters=False find_unused_parameters=False,
) )
return ddp_model return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), **kwargs) return optim.AdamW(model.parameters(), **kwargs)
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
def create_scheduler(
optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.load(optimizer, **kwargs) return SchedulerFactory.load(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict: def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict() return model.module.state_dict()
@ -81,8 +176,8 @@ def train(
start_batch: int, start_batch: int,
accumulation_steps: int, accumulation_steps: int,
warmup_steps: int, warmup_steps: int,
checkpoint_interval: int, ckpt_interval: int,
checkpoint_dir: str, ckpt_dir: str,
dpo_beta: float, dpo_beta: float,
adamw_beta1: float, adamw_beta1: float,
adamw_beta2: float, adamw_beta2: float,
@ -99,48 +194,50 @@ def train(
): ):
assert train_type in ["seq", "sft", "dpo"] assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
parameter = ModelParameter.load(param_path) parameter = ModelParameter.load(param_path)
if window_size is None: if window_size is None:
window_size = parameter.config.max_len window_size = parameter.config.max_len
model = parameter.model model = parameter.model
strategy_kwargs = { strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
"dpo_beta": dpo_beta,
"label_smoothing": label_smoothing
}
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
window_size=window_size, window_size=window_size,
stride=stride stride=stride,
)
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
) )
optimizer_fn = partial(create_optimizer, schedule_config = CosineScheduleConfig(
**{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay}) warmup_steps=warmup_steps,
scheduler_fn = partial(create_scheduler, total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
**{"schedule_config": schedule_config}) )
optimizer_fn = partial(
create_optimizer,
**{
"lr": max_lr,
"betas": (adamw_beta1, adamw_beta2),
"weight_decay": adamw_weight_decay,
},
)
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
train_config = TrainConfig( train_config = TrainConfig(
model=model, model=model,
strategy=train_type, strategy=train_type,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=checkpoint_dir, ckpt_dir=ckpt_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_size=batch_size, batch_size=batch_size,
start_epoch=start_epoch, start_epoch=start_epoch,
start_batch=start_batch, start_batch=start_batch,
checkpoint_interval=checkpoint_interval, ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps, accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
@ -152,11 +249,11 @@ def train(
device_type=device_type, device_type=device_type,
extra_kwargs=strategy_kwargs, extra_kwargs=strategy_kwargs,
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train() trainer.train()
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
train(**vars(args)) train(**vars(args))