style: use the ruff tool to tidy up code style

ViperEkura 2026-03-30 23:32:28 +08:00
parent 345fd2f091
commit 426af2d75f
52 changed files with 1838 additions and 1495 deletions

.github/workflows/lint.yml

@ -0,0 +1,27 @@
name: Lint
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[dev]
- name: Check formatting with ruff
run: |
ruff format --check .

View File

@ -1,17 +0,0 @@
name: Spell Check
on: [push, pull_request]
permissions:
contents: read
jobs:
spellcheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check spelling in specific files
uses: codespell-project/actions-codespell@v2
with:
check_filenames: true
only_warn: false
path: "**/*.{md, py}"

.github/workflows/tests.yml

@ -0,0 +1,31 @@
name: Tests
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[dev]
- name: Run tests with pytest
run: |
python -m pytest tests/ -v

.gitignore

@ -8,5 +8,8 @@
!*.py
!*.md
!*.png
!LICENSE
!pyproject.toml
!pyproject.toml
!.github/workflows/lint.yml
!.github/workflows/tests.yml

View File

@ -54,8 +54,8 @@ python train.py \
--n_epoch=5 \
--batch_size=8 \
--max_lr=2e-4 \
--checkpoint_interval=10000 \
--checkpoint_dir=checkpoints
--ckpt_interval=10000 \
--ckpt_dir=checkpoints
```
**Parameter Explanation:**
@ -67,8 +67,8 @@ python train.py \
- `--accumulation_steps`: Number of batches per training step
- `--warmup_steps`: Warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- `--checkpoint_interval`: Checkpoint saving interval
- `--checkpoint_dir`: Checkpoint saving directory
- `--ckpt_interval`: Checkpoint saving interval
- `--ckpt_dir`: Checkpoint saving directory
- `--resume_dir`: Resume training from specified path
@ -191,8 +191,8 @@ python train.py \
--n_epoch=5 \
--batch_size=8 \
--max_lr=2e-4 \
--checkpoint_interval=10000 \
--checkpoint_dir=checkpoints
--ckpt_interval=10000 \
--ckpt_dir=checkpoints
```
**参数说明:**
@ -204,8 +204,8 @@ python train.py \
- `--accumulation_steps`: 每个训练步骤的 batch 数量
- `--warmup_steps`: 预热步数warmup steps
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
- `--checkpoint_interval`: 检查点保存间隔
- `--checkpoint_dir`: 检查点保存目录
- `--ckpt_interval`: 检查点保存间隔
- `--ckpt_dir`: 检查点保存目录
- `--resume_dir`: 从指定路径恢复训练

View File

@ -2,13 +2,12 @@ import os
from huggingface_hub import snapshot_download
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if __name__ == "__main__":
snapshot_download(
repo_id="ViperEk/KHAOSZ",
local_dir=os.path.join(PROJECT_ROOT, "params"),
force_download=True
)
local_dir=os.path.join(PROJECT_ROOT, "params"),
force_download=True,
)

View File

@ -5,18 +5,18 @@ from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import LoopGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def generate_text():
with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
param.to(device="cuda", dtype=torch.bfloat16)
query = input(">> ")
request = GenerationRequest(
query=query,
temperature=0.8,
@ -28,8 +28,9 @@ def generate_text():
)
generator = LoopGenerator(param)
response = generator.generate(request)
print(response)
if __name__ == "__main__":
generate_text()
generate_text()

View File

@ -4,18 +4,24 @@ from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import BatchGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def batch_generate():
with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
param.to(device="cuda", dtype=torch.bfloat16)
generator = BatchGenerator(param)
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
inputs = [
"你好",
"请问什么是人工智能",
"今天天气如何",
"我感到焦虑, 请问我应该怎么办",
"请问什么是显卡",
]
request = GenerationRequest(
query=inputs,
temperature=0.8,
@ -26,9 +32,10 @@ def batch_generate():
system_prompt=None,
)
responses = generator.generate(request)
for q, r in zip(inputs, responses):
print((q, r))
if __name__ == "__main__":
batch_generate()
batch_generate()

View File

@ -5,16 +5,16 @@ from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import StreamGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def chat():
with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
param.to(device="cuda", dtype=torch.bfloat16)
generator = StreamGenerator(param)
history = []
@ -22,7 +22,7 @@ def chat():
query = input(">> ")
if query == "!exit":
break
request = GenerationRequest(
query=query,
temperature=0.8,
@ -32,7 +32,7 @@ def chat():
history=history,
system_prompt=None,
)
response_size = 0
full_response = ""
for response in generator.generate(request):
@ -40,10 +40,10 @@ def chat():
print(response[response_size:], end="", flush=True)
response_size = len(response)
full_response = response
# After generation, update history
history.append((query, full_response.strip()))
if __name__ == "__main__":
chat()
chat()

View File

@ -6,41 +6,30 @@ from khaosz.config import (
TrainConfig,
)
from khaosz.model.transformer import Transformer
from khaosz.data import (
DatasetLoader,
BpeTokenizer
)
from khaosz.data import DatasetLoader, BpeTokenizer
from khaosz.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
EmbeddingEncoder,
GeneratorFactory
)
from khaosz.trainer import (
Trainer,
StrategyFactory,
SchedulerFactory
GeneratorFactory,
)
from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
__all__ = [
"Transformer",
"ModelConfig",
"TrainConfig",
"DatasetLoader",
"BpeTokenizer",
"GenerationRequest",
"LoopGenerator",
"StreamGenerator",
"BatchGenerator",
"EmbeddingEncoder",
"GeneratorFactory",
"Trainer",
"StrategyFactory",
"SchedulerFactory"
]
"SchedulerFactory",
]

View File

@ -1,10 +1,10 @@
from khaosz.config.model_config import ModelConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter
from khaosz.config.schedule_config import (
ScheduleConfig,
CosineScheduleConfig,
ScheduleConfig,
CosineScheduleConfig,
SGDRScheduleConfig,
ScheduleConfigFactory
ScheduleConfigFactory,
)
from khaosz.config.train_config import TrainConfig
@ -13,14 +13,12 @@ __all__ = [
# Base I/O
"BaseModelIO",
"ModelParameter",
# Model configuration
"ModelConfig",
"TrainConfig",
# Schedule configuration
"ScheduleConfig",
"CosineScheduleConfig",
"SGDRScheduleConfig",
"ScheduleConfigFactory",
]
]

View File

@ -14,30 +14,29 @@ class ModelConfig:
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None
rope_theta: Optional[float] = None
# GQA
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self:
config = {}
with open(config_path, 'r') as f:
config.update(json.load(f))
with open(config_path, "r") as f:
config.update(json.load(f))
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)

View File

@ -9,58 +9,57 @@ from khaosz.data.tokenizer import BpeTokenizer
from khaosz.config.model_config import ModelConfig
from khaosz.model.transformer import Transformer
@dataclass
class BaseModelIO:
"""Base class for model I/O operations."""
model: Optional[nn.Module] = field(
default=None,
metadata={"help": "Transformer model."}
default=None, metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer,
metadata={"help": "Tokenizer for the model."}
default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
)
config: ModelConfig = field(
default_factory=ModelConfig,
metadata={"help": "Transformer model configuration."}
metadata={"help": "Transformer model configuration."},
)
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get standardized file paths for model components."""
dir_path = Path(directory)
return {
"model": dir_path / "model.safetensors",
"config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json"
"config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json",
}
def save_components(self, save_dir: Union[str, Path]):
"""Save core model components."""
paths = self._get_file_paths(save_dir)
paths["model"].parent.mkdir(parents=True, exist_ok=True)
if self.model is not None:
st.save_file(self.model.state_dict(), str(paths["model"]))
self.config.save(str(paths["config"]))
self.tokenizer.save(str(paths["tokenizer"]))
def load_components(self, load_dir: Union[str, Path]) -> Self:
"""Load core model components."""
paths = self._get_file_paths(load_dir)
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if self.model is None:
self.model = Transformer(self.config)
if paths["model"].exists():
state_dict = st.load_file(str(paths["model"]))
self.model.load_state_dict(state_dict)
return self
def to(self, *args, **kwargs) -> "BaseModelIO":
"""Move model to device."""
if self.model is not None:
@ -71,13 +70,12 @@ class BaseModelIO:
@dataclass
class ModelParameter(BaseModelIO):
"""Container for model parameters with serialization capabilities."""
@classmethod
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
instance.save_components(save_dir)
@classmethod
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
instance = cls()
return instance.load_components(load_dir)
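For orientation, a minimal sketch (not part of this commit) of the ModelParameter round trip defined above; the directory names are illustrative and a CUDA device is assumed:

```python
import torch
from khaosz.config.param_config import ModelParameter

# Load model.safetensors, config.json and tokenizer.json from a directory.
param = ModelParameter.load("params")
param.to(device="cuda", dtype=torch.bfloat16)

# Write the same three components back out to another directory.
ModelParameter.save(param, "params_copy")
```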

View File

@ -6,35 +6,35 @@ from dataclasses import dataclass, field
@dataclass
class ScheduleConfig(ABC):
"""Base configuration class for learning rate schedulers.
Provides common validation and interface for all schedule types.
"""
schedule_type: str = field(
default="cosine",
metadata={
"help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"]
}
"help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"],
},
)
warmup_steps: int = field(
default=1000,
metadata={"help": "Number of warmup steps."}
default=1000, metadata={"help": "Number of warmup steps."}
)
min_rate: float = field(
default=0.05,
metadata={"help": "Minimum learning rate multiplier."}
default=0.05, metadata={"help": "Minimum learning rate multiplier."}
)
@abstractmethod
def get_kwargs(self) -> Dict[str, Any]:
"""Get configuration kwargs for scheduler creation."""
raise NotImplementedError
def validate(self) -> None:
"""Validate configuration parameters."""
if self.warmup_steps < 0:
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
raise ValueError(
f"warmup_steps must be non-negative, got {self.warmup_steps}"
)
if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@ -42,44 +42,43 @@ class ScheduleConfig(ABC):
@dataclass
class CosineScheduleConfig(ScheduleConfig):
"""Cosine annealing learning rate schedule configuration."""
total_steps: int = field(
default=None,
metadata={"help": "Total training steps for cosine schedule."}
default=None, metadata={"help": "Total training steps for cosine schedule."}
)
def __post_init__(self) -> None:
self.schedule_type = "cosine"
self.validate()
def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None:
raise ValueError("total_steps must be specified for cosine schedule")
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate
"min_rate": self.min_rate,
}
def validate(self) -> None:
super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
raise ValueError(
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
)
@dataclass
class SGDRScheduleConfig(ScheduleConfig):
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
cycle_length: int = field(
default=1000,
metadata={"help": "Length of the first cycle in steps."}
default=1000, metadata={"help": "Length of the first cycle in steps."}
)
t_mult: int = field(
default=2,
metadata={"help": "Multiplier for cycle length growth."}
t_mult: int = field(
default=2, metadata={"help": "Multiplier for cycle length growth."}
)
def __post_init__(self) -> None:
@ -92,9 +91,9 @@ class SGDRScheduleConfig(ScheduleConfig):
"warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"t_mult": self.t_mult
"t_mult": self.t_mult,
}
def validate(self) -> None:
super().validate()
if self.cycle_length <= 0:
@ -105,33 +104,33 @@ class SGDRScheduleConfig(ScheduleConfig):
class ScheduleConfigFactory:
"""Factory class for creating ScheduleConfig instances.
Supports both direct instantiation and factory creation methods.
Example usage:
# Direct creation
config = CosineScheduleConfig(total_steps=10000)
# Factory method
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
"""
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
"cosine": CosineScheduleConfig,
"sgdr": SGDRScheduleConfig,
}
@classmethod
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
"""Create a schedule config instance.
Args:
schedule_type: Type of schedule ("cosine", "sgdr")
**kwargs: Arguments passed to the config constructor
Returns:
ScheduleConfig instance
Raises:
ValueError: If schedule_type is not supported
"""
@ -140,11 +139,11 @@ class ScheduleConfigFactory:
f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
)
config_cls = cls.CONFIG_MAP[schedule_type]
return config_cls(**kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of available schedule type names."""
return list(cls.CONFIG_MAP.keys())
return list(cls.CONFIG_MAP.keys())
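A short usage sketch for the schedule configs diffed above, following the factory docstring (not part of this commit; step counts are illustrative):

```python
from khaosz.config.schedule_config import CosineScheduleConfig, ScheduleConfigFactory

# Direct construction; warmup_steps and min_rate keep their defaults unless overridden.
cosine = CosineScheduleConfig(total_steps=10000, warmup_steps=1000)
print(cosine.get_kwargs())
# -> {'schedule_type': 'cosine', 'warmup_steps': 1000, 'lr_decay_steps': 9000, 'min_rate': 0.05}

# Equivalent factory route; an unknown schedule_type raises ValueError.
sgdr = ScheduleConfigFactory.create("sgdr", cycle_length=2000, t_mult=2)
print(ScheduleConfigFactory.available_types())   # ['cosine', 'sgdr']
```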

View File

@ -10,127 +10,92 @@ from typing import Callable, List, Optional
@dataclass
class TrainConfig:
# basic setting
model: nn.Module = field(
default=None,
metadata={"help": "Model for training."}
)
strategy: str = field(
default=None,
metadata={"help": "Training strategy."}
)
dataset: Dataset = field(
default=None,
metadata={"help": "Dataset for training."}
)
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."})
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None,
metadata={"help": "Optimizer factory for training."}
default=None, metadata={"help": "Optimizer factory for training."}
)
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None,
metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(
default=1,
metadata={"help": "Number of epochs for training."}
)
batch_size: int = field(
default=4,
metadata={"help": "Batch size for training."}
default=None, metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
default=1, metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
default=1.0, metadata={"help": "Maximum gradient norm."}
)
# checkpoint setting
start_epoch: int = field(
default=0,
metadata={"help": "Start epoch for training."}
)
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
start_batch: int = field(
default=0,
metadata={"help": "Start batch iteration for training."}
default=0, metadata={"help": "Start batch iteration for training."}
)
checkpoint_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
ckpt_dir: str = field(
default="./checkpoint", metadata={"help": "Checkpoint directory."}
)
checkpoint_interval: int = field(
default=5000,
metadata={"help": "Number of iterations between checkpoints."}
ckpt_interval: int = field(
default=5000, metadata={"help": "Number of iterations between checkpoints."}
)
# dataloader setting
random_seed: int = field(
default=3407,
metadata={"help": "Random seed."}
)
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field(
default=0,
metadata={"help": "Number of workers for dataloader."}
default=0, metadata={"help": "Number of workers for dataloader."}
)
prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "Prefetch factor for dataloader."}
default=None, metadata={"help": "Prefetch factor for dataloader."}
)
pin_memory: bool = field(
default=False,
metadata={"help": "Pin memory for dataloader."}
default=False, metadata={"help": "Pin memory for dataloader."}
)
# distributed training
nprocs: int = field(
default=1,
metadata={"help": "Number of processes for distributed training."}
default=1, metadata={"help": "Number of processes for distributed training."}
)
backend: str = field(
default="nccl",
metadata={"help": "Distributed training backend."}
default="nccl", metadata={"help": "Distributed training backend."}
)
master_addr: str = field(
default="localhost",
metadata={"help": "Master address for distributed training."}
metadata={"help": "Master address for distributed training."},
)
master_port: str = field(
default="29500",
metadata={"help": "Master port for distributed training."}
default="29500", metadata={"help": "Master port for distributed training."}
)
parallel_wrapper: Optional[Callable] = field(
default=None,
metadata={"help": "Parallel function for training."}
default=None, metadata={"help": "Parallel function for training."}
)
state_dict_fn: Optional[Callable] = field(
default=None,
metadata={"help": "Parallel function for state dict saving."}
default=None, metadata={"help": "Parallel function for state dict saving."}
)
# others
device_ids: Optional[List[int]] = field(
default=None,
metadata={"help": "Device ids for distributed training."}
default=None, metadata={"help": "Device ids for distributed training."}
)
device_type: str = field(
default="cuda",
metadata={"help": "Device type for distributed training."}
default="cuda", metadata={"help": "Device type for distributed training."}
)
extra_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Other arguments."}
default_factory=dict, metadata={"help": "Other arguments."}
)
def __post_init__(self):
self.validate()
def validate(self):
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
required_fields = [
"model",
"strategy",
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields:
if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.")
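A minimal sketch of constructing the TrainConfig diffed above (not part of this commit); the stand-in model, dataset and strategy name are placeholders, not values prescribed by the repository:

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from khaosz.config import TrainConfig

model = nn.Linear(8, 8)                          # stand-in model
dataset = TensorDataset(torch.zeros(16, 8))      # stand-in dataset

config = TrainConfig(
    model=model,
    strategy="sft",                              # assumed strategy name
    dataset=dataset,
    optimizer_fn=lambda m: torch.optim.AdamW(m.parameters(), lr=2e-4),
    scheduler_fn=lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0),
    n_epoch=5,
    batch_size=8,
    ckpt_interval=10000,
    ckpt_dir="checkpoints",
)
# __post_init__ calls validate(), which raises ValueError if model, strategy,
# dataset, optimizer_fn or scheduler_fn is left as None.
```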

View File

@ -1,12 +1,12 @@
from khaosz.data.dataset import (
BaseDataset,
SEQDataset,
DPODataset,
SFTDataset,
BaseDataset,
SEQDataset,
DPODataset,
SFTDataset,
GRPODataset,
MultiSegmentFetcher,
DatasetLoader,
DatasetFactory
DatasetFactory,
)
from khaosz.data.tokenizer import BpeTokenizer
@ -15,21 +15,17 @@ from khaosz.data.sampler import ResumableDistributedSampler
__all__ = [
# Base classes
"BaseDataset",
# Dataset implementations
"SEQDataset",
"SFTDataset",
"DPODataset",
"GRPODataset",
# Fetchers
"MultiSegmentFetcher",
# Factory (DatasetLoader is alias for backward compatibility)
"DatasetLoader",
"DatasetFactory",
# Tokenizer and sampler
"BpeTokenizer",
"ResumableDistributedSampler"
]
"ResumableDistributedSampler",
]

View File

@ -12,40 +12,42 @@ from typing import Callable, List, Dict, Literal, Optional, Union
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx).
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
Returns:
Concatenated tensor of data in the specified range
"""
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
# Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
@ -64,43 +66,44 @@ class BaseSegmentFetcher:
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, muti_segments: Dict):
self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in muti_segments.items()
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.muti_fetchers.values()]
return min(len_list)
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys.
Args:
begin_idx: Starting index
end_idx: Ending index
keys: Single key or list of keys to fetch
Returns:
Dictionary of tensors if multiple keys, single tensor if one key
"""
fetch_dict = {}
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.muti_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
@ -108,10 +111,10 @@ class MultiSegmentFetcher:
class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types.
Implements common functionality for window-based data fetching.
"""
def __init__(self, window_size: int, stride: int):
super().__init__()
self.segments = {}
@ -122,38 +125,38 @@ class BaseDataset(Dataset, ABC):
def load(self, load_path: str):
"""Load dataset from HDF5 file.
Args:
load_path: Path to the HDF5 data file
"""
self.segments = load_h5(load_path)
self.fetcher = MultiSegmentFetcher(self.segments)
self.total_samples = len(self.fetcher)
def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample.
Args:
index: Sample index
Returns:
Tuple of (begin_idx, end_idx)
"""
assert self.total_samples > self.window_size
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
return begin_idx, end_idx
@abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Get a single sample by index.
Must be implemented by subclasses.
"""
raise NotImplementedError
def __len__(self) -> int:
assert self.total_samples is not None
if self.total_samples <= self.window_size:
@ -163,48 +166,50 @@ class BaseDataset(Dataset, ABC):
class DatasetFactory:
"""Factory class for creating dataset instances.
Supports decorator-based registration for extensible dataset types.
All default dataset types (seq, sft, dpo, grpo) are registered automatically
when their classes are defined with the decorator.
Example usage:
@DatasetFactory.register("custom")
class CustomDataset(BaseDataset):
...
dataset = DatasetFactory.create("custom", window_size, stride)
"""
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
DATASET_MAP: Dict[str, type] = {}
@classmethod
def register(cls, name: str):
"""Decorator to register a new dataset class.
Args:
name: Registration name for the dataset type
Returns:
Decorator function that registers the dataset class
"""
def decorator(dataset_cls: type) -> type:
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
cls.DATASET_MAP[name] = dataset_cls
return dataset_cls
return decorator
@classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
"""Create a dataset instance.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
window_size: Window size for data sampling
stride: Stride between consecutive samples
Returns:
Dataset instance
"""
@ -213,36 +218,42 @@ class DatasetFactory:
f"Unknown dataset type: '{train_type}'. "
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
)
if train_type not in cls.DATASET_MAP:
raise NotImplementedError(
f"Dataset type '{train_type}' is supported but not yet implemented."
)
dataset_cls = cls.DATASET_MAP[train_type]
return dataset_cls(window_size, stride)
@classmethod
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
def load(
cls,
train_type: str,
load_path: str,
window_size: int,
stride: Optional[int] = None,
) -> BaseDataset:
"""Create and load a dataset in one step.
Args:
train_type: Type of training dataset
load_path: Path to the data file
window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size)
Returns:
Loaded dataset instance
"""
if stride is None:
stride = window_size
dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path)
return dataset
@classmethod
def available_types(cls) -> list:
"""Return list of registered dataset type names."""
@ -256,46 +267,50 @@ class DatasetFactory:
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
return {"input_ids": x, "target_ids": y}
@DatasetFactory.register("sft")
class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
dtype=torch.long
)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
@DatasetFactory.register("dpo")
class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@ -304,25 +319,34 @@ class DPODataset(BaseDataset):
def __getitem__(self, index: int):
begin_idx, end_idx = self.get_index(index)
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
dtype=torch.bool
)
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
dtype=torch.bool
)
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
return {
"chosen": chosen,
"rejected": rejected,
"chosen_mask": chosen_mask,
"rejected_mask": rejected_mask,
}
@DatasetFactory.register("grpo")
class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index)
@ -330,8 +354,13 @@ class GRPODataset(BaseDataset):
responses = self._fetch_data(begin_idx, end_idx, "responses")
masks = self._fetch_data(begin_idx, end_idx, "masks")
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
return {
"prompts": prompts,
"responses": responses,
"masks": masks,
"rewards": rewards,
}
# Backward compatibility alias
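A short usage sketch for the dataset factory diffed above (not part of this commit; the data directories are hypothetical):

```python
from khaosz.data import DatasetFactory

# Create an empty sequential dataset, then load HDF5 shards into it.
dataset = DatasetFactory.create("seq", window_size=1024, stride=1024)
dataset.load("data/pretrain")        # directory containing *.h5 / *.hdf5 files

# Or do both in one step; stride defaults to window_size when omitted.
sft_dataset = DatasetFactory.load("sft", "data/sft", window_size=1024)
print(len(sft_dataset), DatasetFactory.available_types())
```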

View File

@ -7,45 +7,45 @@ from typing import Optional
class ResumableDistributedSampler(Sampler[int]):
def __init__(
self,
self,
data_source: Dataset,
start_epoch: int=0,
start_iter: int=0,
seed: int=42,
drop_last: bool=False,
shuffle: bool=True,
process_group: Optional[dist.ProcessGroup]=None,
start_epoch: int = 0,
start_iter: int = 0,
seed: int = 42,
drop_last: bool = False,
shuffle: bool = True,
process_group: Optional[dist.ProcessGroup] = None,
):
self.epoch = start_epoch
self.iter = start_iter
self.seed = seed
self.num_samples = len(data_source)
if process_group is not None:
# input process group
self.rank = dist.get_rank(process_group)
self.num_replicas = dist.get_world_size(process_group)
elif dist.is_available() and dist.is_initialized():
# use default process group
process_group = dist.group.WORLD
self.rank = dist.get_rank()
self.num_replicas = dist.get_world_size()
else:
# single process
self.rank = 0
self.num_replicas = 1
self.drop_last = drop_last
self.shuffle = shuffle
offset = 0 if drop_last else self.num_replicas - 1
offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas
self._indices = None
def _get_indices(self):
if self.shuffle:
generator = torch.Generator()
@ -53,26 +53,26 @@ class ResumableDistributedSampler(Sampler[int]):
indices = torch.randperm(self.num_samples, generator=generator).tolist()
else:
indices = torch.arange(self.num_samples).tolist()
if not self.drop_last and self.num_samples < self.total_size:
padding_size = self.total_size - len(indices)
indices += indices[:padding_size]
local_indices = indices[self.rank:self.total_size:self.num_replicas]
local_indices = indices[self.rank : self.total_size : self.num_replicas]
self.iter = self.iter % self.num_samples_per_replica
self._indices = local_indices[self.iter:]
self._indices = local_indices[self.iter :]
def __iter__(self):
if self._indices is None:
self._get_indices()
for i in self._indices:
self.iter += 1
yield i
self.epoch += 1
self._indices = None
def __len__(self):
return self.num_samples_per_replica
return self.num_samples_per_replica
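A minimal single-process sketch of the resumable sampler diffed above (not part of this commit); with no process group initialized it behaves like a plain resumable sampler:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from khaosz.data import ResumableDistributedSampler

dataset = TensorDataset(torch.arange(100))
# Resume mid-epoch: skip the first 10 samples of epoch 2 on this rank.
sampler = ResumableDistributedSampler(dataset, start_epoch=2, start_iter=10, seed=42)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for (batch,) in loader:
    ...  # training step
```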

View File

@ -10,24 +10,26 @@ from torch import Tensor
from typing import Any, Dict, List
from khaosz.parallel.setup import get_rank
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, 'w') as f:
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f'data_{idx}', data=arr)
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, 'r') as f:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
@ -37,7 +39,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
@ -60,7 +62,7 @@ class Checkpoint:
self,
save_dir: str,
) -> None:
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
@ -72,7 +74,7 @@ class Checkpoint:
}
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
@classmethod
@ -83,7 +85,7 @@ class Checkpoint:
rank = get_rank()
save_path = Path(save_dir)
meta = {}
if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f:
@ -100,4 +102,4 @@ class Checkpoint:
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
)
)
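A usage sketch for the HDF5 helpers diffed above (not part of this commit). Their exact import path is not visible in this diff, so the one below is an assumption; directory and key names are illustrative:

```python
import torch
from khaosz.data.util import save_h5, load_h5   # hypothetical import path

# Group tensors by key and write them to data/sft/shard_000.h5, one dataset per tensor.
tensor_group = {
    "sequence": [torch.arange(10), torch.arange(20)],
    "loss_mask": [torch.ones(10, dtype=torch.bool), torch.ones(20, dtype=torch.bool)],
}
save_h5("data/sft", "shard_000", tensor_group)

# load_h5 walks the directory for *.h5 / *.hdf5 files and rebuilds the mapping.
loaded = load_h5("data/sft")
print({k: len(v) for k, v in loaded.items()})    # {'sequence': 2, 'loss_mask': 2}
```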

View File

@ -9,34 +9,46 @@ class BpeTokenizer:
def __init__(self, path=None):
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
self._special_tokens = ["<|im_start|>", "<|im_end|>"]
model = BPE()
self._tokenizer = Tokenizer(model)
self._tokenizer.normalizer = normalizers.Sequence([
normalizers.NFC(),
normalizers.Strip()
])
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.UnicodeScripts(),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)
])
self._tokenizer.normalizer = normalizers.Sequence(
[normalizers.NFC(), normalizers.Strip()]
)
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.UnicodeScripts(),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
]
)
self._tokenizer.decoder = decoders.ByteLevel()
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
if path is not None:
self._tokenizer = Tokenizer.from_file(path)
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple:
def _prepare_trainer(
self,
vocab_size: int,
min_freq: int,
reserved_token_size: int,
max_token_length=18,
) -> tuple:
assert reserved_token_size > len(self._special_tokens)
reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
reserved_tokens = [
f"<|reserve{i:02d}|>"
for i in range(reserved_token_size - len(self._special_tokens))
]
detail_vocab_size = vocab_size - (
len(reserved_tokens) + len(self._special_tokens)
)
alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self._control_tokens)
assert detail_vocab_size > min_size
trainer = BpeTrainer(
vocab_size=detail_vocab_size,
min_frequency=min_freq,
@ -46,61 +58,74 @@ class BpeTokenizer:
initial_alphabet=alphabet,
show_progress=True,
)
return trainer, detail_vocab_size, reserved_tokens
def train(self, files, vocab_size, min_freq, reserved_token_size=100):
trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size,
min_freq=min_freq,
reserved_token_size=reserved_token_size
reserved_token_size=reserved_token_size,
)
self._tokenizer.train(files=files, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
def train_from_iterator(
self, iterator, vocab_size, min_freq, reserved_token_size=100
):
trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size,
min_freq=min_freq,
reserved_token_size=reserved_token_size
reserved_token_size=reserved_token_size,
)
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def save(self, path):
self._tokenizer.save(path)
def load(self, path):
self._tokenizer = Tokenizer.from_file(path)
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
add_special_tokens: bool = False,
) -> List:
if isinstance(tokens, str):
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
encoded: Encoding = self._tokenizer.encode(
tokens, add_special_tokens=add_special_tokens
)
return encoded.ids if out_ids else encoded.tokens
elif isinstance(tokens, list):
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
encoded_list: List[Encoding] = self._tokenizer.encode_batch(
tokens, add_special_tokens=add_special_tokens
)
return [
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
]
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int:
return self._tokenizer.get_vocab_size()
@property
def stop_ids(self) -> List[int]:
stop_token = self._control_tokens + self._special_tokens
stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
return stop_ids
@property
def bos_id(self) -> int:
return self._tokenizer.token_to_id("<bos>")
@property
def eos_id(self) -> int:
return self._tokenizer.token_to_id("<eos>")
@property
def pad_id(self) -> int:
return self._tokenizer.token_to_id("<pad>")
return self._tokenizer.token_to_id("<pad>")
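A minimal sketch of the tokenizer API diffed above (not part of this commit); the toy corpus and vocab size are illustrative:

```python
from khaosz.data import BpeTokenizer

tokenizer = BpeTokenizer()
corpus = ["hello world", "你好, 世界"] * 200     # toy corpus
tokenizer.train_from_iterator(corpus, vocab_size=2000, min_freq=2)
tokenizer.save("tokenizer.json")

ids = tokenizer.encode("hello world")            # list of token ids
text = tokenizer.decode(ids)                     # back to a string
print(len(tokenizer), ids, text)
```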

View File

@ -11,7 +11,7 @@ from khaosz.inference.generator import (
StreamGenerator,
BatchGenerator,
EmbeddingEncoder,
GeneratorFactory
GeneratorFactory,
)
__all__ = [
@ -19,11 +19,10 @@ __all__ = [
"GeneratorCore",
"EmbeddingEncoderCore",
"KVCacheManager",
"GenerationRequest",
"LoopGenerator",
"StreamGenerator",
"BatchGenerator",
"EmbeddingEncoder",
"GeneratorFactory"
]
"GeneratorFactory",
]

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from torch import Tensor
from torch import Tensor
from contextlib import contextmanager
from typing import Any, Callable, List, Tuple, Union, Optional, Self
from khaosz.config import ModelParameter, ModelConfig
@ -12,58 +12,61 @@ def apply_sampling_strategies(
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf")
filter_value: float = -float("inf"),
) -> Tensor:
"""
"""
Apply sampling strategies to the logits tensor.
Args:
logits (Tensor): The logits tensor.
temperature (float): The temperature parameter.
top_k (int): The top-k parameter.
top_p (float): The top-p parameter.
filter_value (float, optional): The filter value. Defaults to -float("inf").
Returns:
Tensor: The sampled logits tensor.
"""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
@contextmanager
def disable_random_init():
init_functions = [
'xavier_normal_', 'xavier_uniform_',
'kaiming_normal_', 'kaiming_uniform_',
'zeros_', 'ones_', 'constant_',
'normal_', 'uniform_'
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
"kaiming_uniform_",
"zeros_",
"ones_",
"constant_",
"normal_",
"uniform_",
]
original_funcs = {}
for name in init_functions:
@ -82,7 +85,7 @@ class GeneratorCore:
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def generate_iterator(
self,
input_ids: Tensor,
@ -91,18 +94,18 @@ class GeneratorCore:
top_p: float,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0
)-> Tuple[Tensor, int]:
start_pos: int = 0,
) -> Tuple[Tensor, int]:
with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1)
cache_increase = input_ids.size(-1)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
return next_token_id, cache_increase
def generate_loop(
@ -115,14 +118,21 @@ class GeneratorCore:
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0,
callback: Optional[Callable[..., Any]] = None
callback: Optional[Callable[..., Any]] = None,
) -> List[int]:
cur_cache_pos = start_pos
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
input_ids,
temperature,
top_k,
top_p,
attn_mask,
kv_caches,
cur_cache_pos,
)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
@ -132,9 +142,9 @@ class GeneratorCore:
if next_token_id.item() in self.tokenizer.stop_ids:
break
return ids
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
@ -145,32 +155,35 @@ class EmbeddingEncoderCore:
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.max_len
all_fragments = []
fragment_origin_idx = []
for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
fragments = [
seq[j : j + max_model_len]
for j in range(0, len(seq), max_model_len)
]
all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments))
else:
all_fragments.append(seq)
fragment_origin_idx.append(i)
#if empty fragments
# if empty fragments
if not all_fragments or not ids:
return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = []
masks = []
for seq in all_fragments:
@ -179,24 +192,30 @@ class EmbeddingEncoderCore:
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq)
masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
indices = [
idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
]
if indices:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sum_frags = torch.sum(
fragment_embs[indices, :, :], dim=1
) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(
1
) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten())
if with_batch:
return [emb.flatten() for emb in sentence_embs]
else:
@ -209,11 +228,11 @@ class EmbeddingEncoderCore:
class KVCacheManager:
def __init__(
self,
config: ModelConfig,
self,
config: ModelConfig,
batch_size: int,
device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16
device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
self.batch_size = batch_size
self.device = device
@ -221,25 +240,41 @@ class KVCacheManager:
self.num_layers = config.n_layers
self.max_len = config.max_len
self.num_heads = config.n_kv_heads
self.head_dim = config.dim //config.n_heads
self.head_dim = config.dim // config.n_heads
self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None
self._initialize()
def _initialize(self):
k_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
(
self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
(
self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
)
self._kv_cache = (k_cache, v_cache)
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)
self._seq_mask = torch.ones(
(self.batch_size, self.max_len), device=self.device, dtype=torch.bool
)
def update(self, active_mask: Tensor):
def update(self, active_mask: Tensor):
k_cache, v_cache = self._kv_cache
self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
self._seq_mask = self._seq_mask[active_mask]
@ -250,14 +285,14 @@ class KVCacheManager:
self._seq_mask = None
else:
self._initialize()
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
batch_size, seq_len = input_ids.shape
bool_mask = (input_ids != pad_id)
self._seq_mask[: batch_size, : seq_len] = bool_mask
bool_mask = input_ids != pad_id
self._seq_mask[:batch_size, :seq_len] = bool_mask
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
return self._kv_cache
def get_seq_mask(self) -> Tensor:
return self._seq_mask
def get_seq_mask(self) -> Tensor:
return self._seq_mask
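A small self-contained sketch of the sampling helper diffed above (not part of this commit), showing temperature, top-k and top-p filtering on random logits:

```python
import torch
from khaosz.inference.core import apply_sampling_strategies

logits = torch.randn(1, 32)                      # [batch, vocab_size]
filtered = apply_sampling_strategies(logits.clone(), temperature=0.8, top_k=5, top_p=0.9)

# Tokens removed by top-k / top-p are set to -inf, so they get zero probability.
probs = torch.softmax(filtered, dim=-1)
print((probs > 0).sum().item())                  # at most 5 candidates remain
next_token = torch.multinomial(probs, num_samples=1)
```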

View File

@ -1,6 +1,6 @@
import torch
from dataclasses import dataclass
from torch import Tensor
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
from khaosz.config.param_config import ModelParameter
@ -8,10 +8,11 @@ from khaosz.config.param_config import ModelParameter
HistoryType = List[Tuple[str, str]]
def build_prompt(
query: str,
system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None
history: Optional[HistoryType] = None,
) -> str:
"""
Build prompt in ChatML format for query and history.
@ -42,17 +43,17 @@ def build_prompt(
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
"""
"""
Pad a list of sequences to a fixed length.
Args:
ids_list (List[List[int]]): A list of sequences.
max_ids_len (int): The maximum length of sequences.
pad_id (int): The id to pad sequences.
Returns:
List[List[int]]: A list of padded sequences.
"""
max_ids_len = max(len(ids) for ids in ids_list)
new_ids_list = []
@ -60,7 +61,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
pad_len = max_ids_len - len(ids)
padded_seq = [pad_id] * pad_len + ids
new_ids_list.append(padded_seq)
return new_ids_list, max_ids_len
@ -68,7 +69,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
class GenerationRequest:
"""
Request parameters for text generation.
Attributes:
top_k: Top-k sampling parameter.
top_p: Top-p (nucleus) sampling parameter.
@ -79,6 +80,7 @@ class GenerationRequest:
system_prompt: System prompt for the conversation.
stream: Whether to use streaming generation.
"""
top_k: int
top_p: float
temperature: float
@ -101,63 +103,66 @@ class GenerationRequest:
class LoopGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> str:
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device)
prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
self.model.eval()
kv_caches = cache_manager.get_kvcache()
ids = self.generate_loop(
input_ids,
ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
input_ids,
ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
)
response = self.tokenizer.decode(ids[start_cache_pos:])
return response
class StreamGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device)
prompt = build_prompt(request.query, request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, request.temperature, request.top_k, request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos
input_ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:])
yield response
if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n"
break
@ -166,131 +171,140 @@ class StreamGenerator(GeneratorCore):
class BatchGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> List[str]:
batch_size = len(request.query)
if request.history is None:
request.history = [[] for _ in range(batch_size)]
prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
prompts = [
build_prompt(query, history)
for query, history in zip(request.query, request.history)
]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, batch_size, device=device)
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
activate_task_mask = [True] * batch_size
start_cache_pos = max_ids_len
cur_cache_pos = 0
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache()
attn_mask =cache_manager.get_seq_mask()
attn_mask = cache_manager.get_seq_mask()
next_token_id, cache_increase = self.generate_iterator(
input_tensor, request.temperature, request.top_k, request.top_p,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos
input_tensor,
request.temperature,
request.top_k,
request.top_p,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos,
)
cur_cache_pos += cache_increase
active_mask = []
c_ids = 0
for i in range(batch_size):
if activate_task_mask[i]:
token = next_token_id[c_ids, :].item()
ids_list[i].append(token)
c_ids += 1
is_active = not token in self.tokenizer.stop_ids
activate_task_mask[i] = is_active
activate_task_mask[i] = is_active
active_mask.append(is_active)
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
cache_manager.update(active_mask)
input_tensor = next_token_id[active_mask, :]
max_ids_len += 1
responses = [str()] * batch_size
for i in range(batch_size):
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
request.history[i].append((request.query[i], responses[i]))
return responses
class EmbeddingEncoder(EmbeddingEncoderCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
return super().encode(sentence)
class GeneratorFactory:
"""Factory class for creating generator instances.
Provides smart generator selection based on request characteristics:
- Streaming: Use StreamGenerator for streaming output
- Batch: Use BatchGenerator when query is a list
- Single: Use LoopGenerator for single query non-streaming
Example usage:
generator = GeneratorFactory.create_generator(parameter, request)
result = generator.generate(request)
"""
@staticmethod
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
def create_generator(
parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Create a generator based on request characteristics.
Args:
parameter: Model parameters containing model, tokenizer, config
request: Generation request with query, options, etc.
Returns:
Appropriate GeneratorCore subclass instance
"""
# Streaming generation: check stream field first
if request.stream:
return StreamGenerator(parameter)
# Batch generation: query is a list of strings
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default: single query non-streaming
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
"""Create an embedding encoder instance.
Args:
parameter: Model parameters
Returns:
EmbeddingEncoderCore instance
"""
return EmbeddingEncoder(parameter)
@classmethod
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
def create(
cls, parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Convenience method that delegates to create_generator.
Args:
parameter: Model parameters
request: Generation request
Returns:
Generator instance
"""
return cls.create_generator(parameter, request)
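A usage sketch for the generator factory diffed above (not part of this commit). The parameter directory and sampling values are illustrative, and only request fields visible in this diff are shown; the real GenerationRequest may take additional fields:

```python
import torch
from khaosz.config import ModelParameter
from khaosz.inference.generator import GenerationRequest, GeneratorFactory

param = ModelParameter.load("params")
param.to(device="cuda", dtype=torch.bfloat16)

request = GenerationRequest(
    query="请问什么是人工智能",   # single non-streaming query -> LoopGenerator
    temperature=0.8,
    top_k=50,
    top_p=0.95,
    stream=False,
    history=None,
    system_prompt=None,
)

generator = GeneratorFactory.create(param, request)
print(generator.generate(request))
```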

View File

@ -1,17 +1,10 @@
from khaosz.model.module import (
from khaosz.model.module import (
Linear,
RMSNorm,
RMSNorm,
MLP,
GQA,
DecoderBlock,
)
from khaosz.model.transformer import Transformer
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
"Transformer"
]
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]

View File

@ -7,7 +7,7 @@ from typing import Optional, Tuple
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""
"""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
@ -15,7 +15,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
@ -25,12 +25,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
.reshape(bs, slen, n_heads * n_rep, head_dim)
)
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
) -> Tuple[Tensor, Tensor]:
"""
dim: int,
max_len: int,
base: float = 10000,
) -> Tuple[Tensor, Tensor]:
"""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
@ -46,6 +47,7 @@ def get_rotary_emb(
return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""
Apply rotary embedding to the input tensor using cos/sin form.
@ -55,49 +57,49 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int=10000):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self.max_len_cached = None
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]:
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(seq_len + start_pos)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin)
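
A shape-only sketch of how the rotary helpers above fit together; the sizes are arbitrary assumptions.

```python
import torch

rope = RotaryEmbedding(dim=8, max_len=16)
x = torch.randn(2, 4, 2, 8)             # (bsz, seq_len, n_heads, head_dim)
cos, sin = rope(x, start_pos=0)         # each (seq_len, head_dim // 2) == (4, 4)
x_rot = apply_rotary_emb(x, (cos, sin))
assert x_rot.shape == x.shape           # rotation preserves the input shape
```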
@ -115,43 +117,42 @@ class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim, )
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
return rms.to(x.dtype)
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
class GQA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
self,
dim: int,
n_heads: int,
n_kv_heads: int,
use_qk_norm: bool,
norm_eps: float,
use_gated_attention: bool,
layer_id: int
layer_id: int,
):
super().__init__()
assert dim % n_heads == 0
assert n_heads % n_kv_heads == 0
self.head_dim = dim // n_heads
self.layer_id = layer_id
self.dim = dim
@ -165,11 +166,11 @@ class GQA(nn.Module):
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
if self.use_gated_attention:
self.gate = Linear(dim, dim)
@ -177,14 +178,14 @@ class GQA(nn.Module):
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
@ -194,31 +195,36 @@ class GQA(nn.Module):
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None:
k_cache, v_cache = kv_cache
# copy to cache
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, n_heads*head_dim)
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
)
if self.use_gated_attention:
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
@ -227,15 +233,15 @@ class GQA(nn.Module):
class MLA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim: int,
n_heads: int,
n_kv_heads: int,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
norm_eps: float,
use_gated_attention: bool,
layer_id: int
layer_id: int,
):
super().__init__()
self.dim = dim
@ -252,45 +258,46 @@ class MLA(nn.Module):
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v)
self.kv_b_proj = Linear(
kv_lora_rank,
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
)
self.o_proj = Linear(dim, dim, bias=False)
if use_gated_attention:
self.gate = Linear(dim, dim, bias=False)
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
kv_compressed = self.kv_a_proj(x)
kv_compressed = self.kv_norm(kv_compressed)
kv = self.kv_b_proj(kv_compressed)
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
k_nope, k_rope, v = torch.split(
kv,
[self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim],
dim=-1
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
)
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_rope_head_dim :],
)
q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_rope_head_dim:]
q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb)
@ -299,41 +306,48 @@ class MLA(nn.Module):
if kv_cache is not None:
k_cache, v_cache = kv_cache
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention:
attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out)
return out
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.attention = GQA(dim, n_heads, n_kv_heads,
use_qk_norm, norm_eps, use_gated_attention, layer_id)
self.attention = GQA(
dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
@ -341,24 +355,20 @@ class DecoderBlock(nn.Module):
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0
start_pos: int = 0,
) -> Tensor:
# attention
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
kv_cache,
start_pos
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
)
x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x
return x
@ -366,6 +376,6 @@ class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)
return F.embedding(x, self.weight)


@ -4,15 +4,21 @@ import torch.nn as nn
from torch import Tensor
from typing import Any, Mapping, Optional, Tuple
from khaosz.config.model_config import ModelConfig
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
from khaosz.model.module import (
Embedding,
DecoderBlock,
Linear,
RMSNorm,
RotaryEmbedding,
)
def process_attention_mask(
seq_mask: Tensor,
input_tensor: Tensor,
start_pos: int = 0,
is_causal: bool = False,
) -> Tensor:
seq_mask: Tensor,
input_tensor: Tensor,
start_pos: int = 0,
is_causal: bool = False,
) -> Tensor:
"""
Create attention mask for GQA
Args:
@ -26,32 +32,36 @@ def process_attention_mask(
device = input_tensor.device
dtype = input_tensor.dtype
seq_len = input_tensor.size(1)
if seq_mask is None:
if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else:
return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's already a 4D mask
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
# (bsz, seq_len, start_pos + seq_len)
if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
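
A small worked example of the mask construction above, under assumed shapes (two new tokens decoded with three cached positions):

```python
import torch

seq_mask = torch.ones((2, 5), dtype=torch.bool)  # (bsz, start_pos + seq_len)
x = torch.randn(2, 2, 16)                        # (bsz, seq_len, dim), seq_len = 2
mask = process_attention_mask(seq_mask, x, start_pos=3, is_causal=True)
print(mask.shape)  # torch.Size([2, 1, 2, 5]); masked positions hold a large negative value
```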
@ -59,26 +69,38 @@ class Transformer(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
self.rotary_embeding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList([
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
for layer_id in range(config.n_layers)
])
self.layers = nn.ModuleList(
[
DecoderBlock(
config.dim,
config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight == True:
self.lm_head.weight = self.embed_tokens.weight
self._init_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = 'lm_head.weight'
embed_key = 'embed_tokens.weight'
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
if self.config.tie_weight == True:
# same tensor
@ -87,48 +109,44 @@ class Transformer(nn.Module):
if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign)
def state_dict(self, destination=None, prefix='', keep_vars=False):
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def state_dict(self, destination=None, prefix="", keep_vars=False):
state_dict = super().state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight == True:
lm_head_key = prefix + 'lm_head.weight'
lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict:
del state_dict[lm_head_key]
return state_dict
def _init_parameters(self):
for param in self.parameters():
if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006)
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor]=None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None,
start_pos: int = 0
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embeding(x, start_pos)
attn_mask = process_attention_mask(
input_mask, x, start_pos, is_causal=True
)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
hidden_states = self.norm(x)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)
return {
"logits": logits,
"hidden_states": hidden_states
}
return {"logits": logits, "hidden_states": hidden_states}
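
For orientation, a minimal forward pass mirroring the test fixture later in this diff (the config keys and the ModelConfig().load call are taken from that fixture; everything else is an assumption).

```python
import json
import torch
from khaosz.config.model_config import ModelConfig
from khaosz.model.transformer import Transformer

config = {
    "vocab_size": 1000, "dim": 32, "n_heads": 4, "n_kv_heads": 2,
    "dim_ffn": 64, "max_len": 128, "n_layers": 2, "norm_eps": 1e-5,
}
with open("config.json", "w") as f:
    json.dump(config, f)

model = Transformer(ModelConfig().load("config.json"))
out = model(torch.randint(0, 1000, (2, 16)))  # input_ids of shape (bsz, seq_len)
print(out["logits"].shape)                    # (2, 16, 1000)
```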


@ -1,27 +1,21 @@
from khaosz.parallel.setup import (
get_world_size,
get_world_size,
get_rank,
get_current_device,
only_on_rank,
setup_parallel,
spawn_parallel_fn
setup_parallel,
spawn_parallel_fn,
)
from khaosz.parallel.module import (
RowParallelLinear,
ColumnParallelLinear
)
from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
__all__ = [
"get_world_size",
"get_rank",
"get_current_device",
"only_on_rank",
"setup_parallel",
"spawn_parallel_fn",
"RowParallelLinear",
"ColumnParallelLinear"
"ColumnParallelLinear",
]


@ -17,91 +17,99 @@ class ParallelModel(nn.Module):
class RowParallelLinear(ParallelModel):
def __init__(
self,
self,
process_group: dist.ProcessGroup,
in_features: int,
out_features: int,
in_features: int,
out_features: int,
bias: bool = True,
reduce_results: bool = True
reduce_results: bool = True,
):
super().__init__(process_group)
self.in_features = in_features
self.out_features = out_features
self.in_features_per_rank = in_features // self.world_size
self.reduce_results = reduce_results
if in_features % self.world_size != 0:
raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}")
raise ValueError(
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
)
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight)
if self.reduce_results:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
if self.bias is not None:
output += self.bias
return output
def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get('weight')
full_bias = state_dict.get('bias')
def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get("weight")
full_bias = state_dict.get("bias")
start_idx = self.rank * self.in_features_per_rank
end_idx = start_idx + self.in_features_per_rank
weight_slice = full_weight[:, start_idx:end_idx]
self.weight.data.copy_(weight_slice)
if self.bias is not None:
self.bias.data.copy_(full_bias)
class ColumnParallelLinear(ParallelModel):
def __init__(
self,
self,
process_group: dist.ProcessGroup,
in_features: int,
out_features: int,
in_features: int,
out_features: int,
bias: bool = True,
gather_results: bool = True
gather_results: bool = True,
):
super().__init__(process_group)
self.in_features = in_features
self.out_features = out_features
self.out_features_per_rank = out_features // self.world_size
self.gather_results = gather_results
if out_features % self.world_size != 0:
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
raise ValueError(
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
)
self.weight = nn.Parameter(
torch.empty(self.out_features_per_rank, self.in_features)
)
self.bias = (
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
)
self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight, self.bias)
if self.gather_results:
output_list = [torch.empty_like(output) for _ in range(self.world_size)]
dist.all_gather(output_list, output, group=self.process_group)
output = torch.cat(output_list, dim=-1)
return output
def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get('weight')
full_bias = state_dict.get('bias')
def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get("weight")
full_bias = state_dict.get("bias")
start_idx = self.rank * self.out_features_per_rank
end_idx = start_idx + self.out_features_per_rank
weight_slice = full_weight[start_idx:end_idx, :]
self.weight.data.copy_(weight_slice)
if self.bias is not None:
bias_slice = full_bias[start_idx:end_idx]
self.bias.data.copy_(bias_slice)
self.bias.data.copy_(bias_slice)


@ -11,73 +11,74 @@ from typing import Callable, List, Optional
def get_current_device():
return os.environ["LOCAL_DEVICE"]
def get_world_size() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
else:
return 0
@contextmanager
def setup_parallel(
rank: int,
rank: int,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
device_ids: Optional[List[int]] = None
device_ids: Optional[List[int]] = None,
):
if dist.is_available() and dist.is_initialized():
yield dist.group.WORLD
return
return
if world_size <= 1:
yield None
return
if device_ids is None:
device_ids = [i for i in range(world_size)]
rank = device_ids[rank % len(device_ids)]
device_id = torch.device(device_type, device_ids[rank])
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = master_port
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id)
dist.init_process_group(
rank=rank,
world_size=world_size,
backend=backend,
device_id=device_id
rank=rank, world_size=world_size, backend=backend, device_id=device_id
)
try:
if backend == "nccl" and torch.cuda.is_available():
torch.cuda.set_device(device_id)
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.set_device(device_id)
yield dist.group.WORLD
finally:
if dist.is_initialized():
dist.destroy_process_group()
def only_on_rank(rank, sync=False):
"""
Decorator to run a function only on a specific rank.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
@ -89,67 +90,81 @@ def only_on_rank(rank, sync=False):
dist.barrier()
return ret_args
return wrapper
return decorator
def wrapper_spawn_func(
rank: int,
world_size: int,
backend: str,
master_addr: str,
rank: int,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
device_type: str,
device_ids: List[int],
func: Callable,
kwargs: dict
device_ids: List[int],
func: Callable,
kwargs: dict,
):
try:
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=master_addr,
rank=rank,
world_size=world_size,
backend=backend,
master_addr=master_addr,
master_port=master_port,
device_type=device_type,
device_ids=device_ids
device_ids=device_ids,
):
func(**kwargs)
except Exception as e:
print(f"Error in rank {rank}: {e}")
raise
def spawn_parallel_fn(
func: Callable,
world_size: int,
func: Callable,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
**kwargs
**kwargs,
):
# clear environment variables
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]
if world_size == 1:
device_ids = device_ids or [0]
device_id = torch.device(device_type, device_ids[0])
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
device_type, device_ids, func, kwargs)
wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
device_ids,
func,
kwargs,
)
mp.spawn(
wrapper_spawn_func,
nprocs=world_size,
args=wrapper_spawn_func_args,
join=True
)
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
)
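
A hedged sketch of driving the spawn helper above. It assumes two visible CUDA devices and relies on the fact that extra keyword arguments are forwarded to the worker function.

```python
import torch
import torch.distributed as dist

def worker(message: str):
    rank = get_rank()
    device = get_current_device()   # e.g. "cuda:0", set by setup_parallel
    x = torch.ones(4, device=device)
    dist.all_reduce(x)              # SUM across the two ranks -> tensor of 2s
    print(f"[rank {rank}] {message}: {x.tolist()}")

if __name__ == "__main__":
    spawn_parallel_fn(worker, world_size=2, backend="nccl", message="hello")
```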


@ -14,15 +14,12 @@ from khaosz.trainer.train_callback import (
__all__ = [
# Main trainer
"Trainer",
# Strategy factory
"StrategyFactory",
"BaseStrategy",
# Scheduler factory
"SchedulerFactory",
"BaseScheduler",
# Callbacks
"TrainCallback",
"GradientClippingCallback",
@ -30,4 +27,4 @@ __all__ = [
"CheckpointCallback",
"ProgressBarCallback",
"MetricLoggerCallback",
]
]


@ -1,8 +1,9 @@
import torch.nn as nn
from typing import Dict
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
""" Compute gradient norm for each parameter in the model. """
"""Compute gradient norm for each parameter in the model."""
norms = {}
for name, param in model.named_parameters():
norms[name] = 0.0
@ -11,8 +12,9 @@ def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
norms[name] = norm
return norms
def grad_std(model: nn.Module) -> Dict[str, float]:
""" Compute standard deviation of gradients for each parameter. """
"""Compute standard deviation of gradients for each parameter."""
stds = {}
for name, param in model.named_parameters():
stds[name] = 0.0
@ -21,41 +23,45 @@ def grad_std(model: nn.Module) -> Dict[str, float]:
stds[name] = std
return stds
def grad_max(model: nn.Module) -> Dict[str, float]:
""" Find the maximum absolute gradient value for each parameter. """
"""Find the maximum absolute gradient value for each parameter."""
max_vals = {}
for name, param in model.named_parameters():
max_vals[name] = -float('inf')
max_vals[name] = -float("inf")
if param.grad is not None:
max_val = param.grad.data.max().item()
max_vals[name] = max_val
return max_vals
def grad_min(model: nn.Module) -> Dict[str, float]:
""" Find the minimum absolute gradient value for each parameter. """
"""Find the minimum absolute gradient value for each parameter."""
min_vals = {}
for name, param in model.named_parameters():
min_vals[name] = float('inf')
min_vals[name] = float("inf")
if param.grad is not None:
min_val = param.grad.data.min().item()
min_vals[name] = min_val
return min_vals
def grad_mean(model: nn.Module) -> Dict[str, float]:
""" Compute mean of gradients for each parameter. """
"""Compute mean of gradients for each parameter."""
means = {}
for name, param in model.named_parameters():
means[name] = 0.0
if param.grad is not None:
mean = param.grad.data.mean().item()
means[name] = mean
return means
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
""" Count the number of NaNs in gradients for each parameter. """
"""Count the number of NaNs in gradients for each parameter."""
nan_nums = {}
for name, param in model.named_parameters():
nan_nums[name] = 0
@ -64,26 +70,34 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]:
nan_nums[name] = nan_num
return nan_nums
def ctx_get_loss(ctx):
return ctx.loss
def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]['lr']
return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model)
def ctx_get_grad_std(ctx):
return grad_std(ctx.model)
def ctx_get_grad_max(ctx):
return grad_max(ctx.model)
def ctx_get_grad_min(ctx):
return grad_min(ctx.model)
def ctx_get_grad_mean(ctx):
return grad_mean(ctx.model)
def ctx_get_grad_nan_num(ctx):
return grad_nan_num(ctx.model)
return grad_nan_num(ctx.model)


@ -9,71 +9,75 @@ from khaosz.config.schedule_config import ScheduleConfig
class BaseScheduler(LRScheduler, ABC):
"""Base scheduler class for all other schedulers."""
def __init__(self, optimizer, last_epoch: int = -1):
super().__init__(optimizer, last_epoch)
@abstractmethod
def get_lr(self) -> List[float]:
"""Calculate the current learning rate."""
raise NotImplementedError
def state_dict(self) -> Dict[str, Any]:
return super().state_dict()
def load_state_dict(self, state_dict: Dict[str, Any]):
super().load_state_dict(state_dict)
class SchedulerFactory:
"""Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types.
Also supports creation from ScheduleConfig objects.
Example usage:
@SchedulerFactory.register("custom")
class CustomScheduler(BaseScheduler):
...
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
# Or from config
config = CosineScheduleConfig(total_steps=10000)
scheduler = SchedulerFactory.load(optimizer, config)
"""
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
@classmethod
def register(cls, name: str):
"""Decorator to register a new scheduler class.
Args:
name: Registration name for the scheduler
Returns:
Decorator function that registers the scheduler class
"""
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
raise TypeError(
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
)
cls.SCHEDULER_MAP[name] = scheduler_cls
return scheduler_cls
return decorator
@classmethod
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
"""Create a scheduler instance by type name.
Args:
optimizer: PyTorch optimizer
schedule_type: Type of scheduler ("cosine", "sgdr")
**kwargs: Arguments passed to the scheduler constructor
Returns:
Scheduler instance
Raises:
ValueError: If schedule_type is not supported
"""
@ -82,25 +86,25 @@ class SchedulerFactory:
f"Unknown schedule type: '{schedule_type}'. "
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
)
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
return scheduler_cls(optimizer, **kwargs)
@staticmethod
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
"""Create a scheduler from a ScheduleConfig object.
Args:
optimizer: PyTorch optimizer
schedule_config: ScheduleConfig instance
Returns:
Scheduler instance
"""
kwargs = schedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of registered scheduler type names."""
@ -114,22 +118,21 @@ class SchedulerFactory:
@SchedulerFactory.register("cosine")
class CosineScheduler(BaseScheduler):
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
def __init__(
self,
optimizer,
warmup_steps: int,
lr_decay_steps: int,
min_rate: float = 0.05,
last_epoch: int = -1
self,
optimizer,
warmup_steps: int,
lr_decay_steps: int,
min_rate: float = 0.05,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
self.lr_decay_steps = lr_decay_steps
self.min_rate = min_rate
self.total_steps = warmup_steps + lr_decay_steps
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
# warmup
if self.last_epoch < self.warmup_steps:
@ -142,46 +145,47 @@ class CosineScheduler(BaseScheduler):
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
decay_factor = max(self.min_rate, cosine_decay)
return [base_lr * decay_factor for base_lr in self.base_lrs]
def state_dict(self):
state = super().state_dict()
state.update({
'warmup_steps': self.warmup_steps,
'lr_decay_steps': self.lr_decay_steps,
'min_rate': self.min_rate,
'total_steps': self.total_steps,
})
state.update(
{
"warmup_steps": self.warmup_steps,
"lr_decay_steps": self.lr_decay_steps,
"min_rate": self.min_rate,
"total_steps": self.total_steps,
}
)
return state
def load_state_dict(self, state_dict):
self.warmup_steps = state_dict.pop('warmup_steps')
self.lr_decay_steps = state_dict.pop('lr_decay_steps')
self.min_rate = state_dict.pop('min_rate')
self.total_steps = state_dict.pop('total_steps')
self.warmup_steps = state_dict.pop("warmup_steps")
self.lr_decay_steps = state_dict.pop("lr_decay_steps")
self.min_rate = state_dict.pop("min_rate")
self.total_steps = state_dict.pop("total_steps")
super().load_state_dict(state_dict)
@SchedulerFactory.register("sgdr")
class SGDRScheduler(BaseScheduler):
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
def __init__(
self,
optimizer,
warmup_steps: int,
cycle_length: int,
min_rate: float = 0.05,
t_mult: int = 2,
last_epoch: int = -1,
self,
optimizer,
warmup_steps: int,
cycle_length: int,
min_rate: float = 0.05,
t_mult: int = 2,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
self.cycle_length = cycle_length
self.min_rate = min_rate
self.t_mult = t_mult
super().__init__(optimizer, last_epoch)
def get_lr(self):
# warmup
if self.last_epoch < self.warmup_steps:
@ -190,40 +194,44 @@ class SGDRScheduler(BaseScheduler):
# SGDR
steps_since_warmup = self.last_epoch - self.warmup_steps
# 1. Calculate current cycle and position within cycle
current_cycle_length = self.cycle_length
total_cycles_length = 0
cycle_num = 0
while total_cycles_length + current_cycle_length <= steps_since_warmup:
total_cycles_length += current_cycle_length
current_cycle_length *= self.t_mult
cycle_num += 1
steps_in_cycle = steps_since_warmup - total_cycles_length
# 2. Cosine annealing within the current cycle
cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
cosine_factor = 0.5 * (
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
)
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
def state_dict(self):
"""Returns the state of the scheduler as a dict."""
state = super().state_dict()
state.update({
'warmup_steps': self.warmup_steps,
'cycle_length': self.cycle_length,
'min_rate': self.min_rate,
't_mult': self.t_mult
})
state.update(
{
"warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"t_mult": self.t_mult,
}
)
return state
def load_state_dict(self, state_dict):
"""Loads the scheduler's state."""
self.warmup_steps = state_dict.pop('warmup_steps')
self.cycle_length = state_dict.pop('cycle_length')
self.min_rate = state_dict.pop('min_rate')
self.t_mult = state_dict.pop('t_mult')
super().load_state_dict(state_dict)
self.warmup_steps = state_dict.pop("warmup_steps")
self.cycle_length = state_dict.pop("cycle_length")
self.min_rate = state_dict.pop("min_rate")
self.t_mult = state_dict.pop("t_mult")
super().load_state_dict(state_dict)
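
A usage sketch of the factory and cosine schedule above (numbers are arbitrary). With the SGDR variant and t_mult=2, the restart cycles would instead last cycle_length, 2*cycle_length, 4*cycle_length, and so on.

```python
import torch
from torch.optim import AdamW

model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=2e-4)
scheduler = SchedulerFactory.create(
    optimizer, "cosine", warmup_steps=100, lr_decay_steps=900, min_rate=0.05
)
for _ in range(1000):
    optimizer.step()
    scheduler.step()
# lr warms up for 100 steps, then decays on a cosine down to 5% of the base lr
```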


@ -20,7 +20,7 @@ def unwrap_model(model: nn.Module) -> nn.Module:
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
@ -37,25 +37,27 @@ def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
def get_logprobs(
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
input_ids: Tensor,
mask: Tensor,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
input_ids: Tensor,
mask: Tensor,
reduction: str,
):
"""Compute token-wise log probabilities from model outputs.
Args:
model: The language model
input_ids: Input token IDs of shape [batch_size, seq_len]
mask: Attention mask of shape [batch_size, seq_len]
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
Returns:
Log probabilities with reduction applied over sequence dimension
"""
allowed_reductions = ["mean", "sum", "none"]
if reduction not in allowed_reductions:
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
raise ValueError(
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
)
shifted_input_ids = input_ids[:, 1:]
shifted_mask = mask[:, 1:]
@ -64,13 +66,13 @@ def get_logprobs(
log_probs = torch.log_softmax(logits.float(), dim=-1)
token_logprobs = torch.gather(
log_probs,
dim=-1,
index=shifted_input_ids.unsqueeze(-1)
log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)
if reduction == "mean":
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0)
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
dim=-1
).clamp(min=1.0)
elif reduction == "sum":
return (token_logprobs * shifted_mask).sum(dim=-1)
else:
@ -79,23 +81,25 @@ def get_logprobs(
class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
def __init__(
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
):
self.model = model
self.device = device
@abstractmethod
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
"""Compute loss for the given batch.
Args:
batch: Dictionary containing batch tensors
Returns:
Computed loss tensor
"""
raise NotImplementedError
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
"""Allow calling strategy directly as a callable."""
return self.compute_loss(batch)
@ -103,51 +107,55 @@ class BaseStrategy(ABC):
class StrategyFactory:
"""Factory class for creating training strategy instances.
Supports decorator-based registration for extensible strategy types.
All default strategies (seq, sft, dpo, grpo) are automatically registered.
Example usage:
@StrategyFactory.register("custom")
class CustomStrategy(BaseStrategy):
...
strategy = StrategyFactory.create(model, "custom", device)
"""
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
STRATEGY_MAP: Dict[str, type] = {}
@classmethod
def register(cls, name: str):
"""Decorator to register a new strategy class.
Args:
name: Registration name for the strategy
Returns:
Decorator function that registers the strategy class
"""
def decorator(strategy_cls: type) -> type:
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
raise TypeError(
f"{strategy_cls.__name__} must inherit from BaseStrategy"
)
cls.STRATEGY_MAP[name] = strategy_cls
return strategy_cls
return decorator
@classmethod
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
"""Create a strategy instance based on training type.
Args:
model: Model instance for the strategy
train_type: Type of training ("seq", "sft", "dpo", "grpo")
device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor
Returns:
Strategy instance
Raises:
ValueError: If train_type is not supported
NotImplementedError: If train_type is in supported list but not implemented
@ -157,15 +165,15 @@ class StrategyFactory:
f"Unknown training strategy: '{train_type}'. "
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
)
if train_type not in cls.STRATEGY_MAP:
raise NotImplementedError(
f"Strategy '{train_type}' is supported but not yet implemented."
)
strategy_cls = cls.STRATEGY_MAP[train_type]
return strategy_cls(model, device, **kwargs)
@classmethod
def available_strategies(cls) -> list:
"""Return list of registered strategy names."""
@ -179,77 +187,81 @@ class StrategyFactory:
@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
"""Standard next-token prediction training strategy.
Computes cross-entropy loss for next token prediction.
"""
def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
logits = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
label_smoothing=self.label_smoothing
label_smoothing=self.label_smoothing,
)
return loss
@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
"""Supervised Fine-tuning strategy with loss masking.
Applies cross-entropy loss only to tokens where loss_mask is True.
"""
def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
input_ids, target_ids, loss_mask = (
batch["input_ids"],
batch["target_ids"],
batch["loss_mask"],
)
ignore_index = -100
logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
ignore_index=ignore_index,
label_smoothing=self.label_smoothing
label_smoothing=self.label_smoothing,
)
return loss
@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy):
"""Direct Preference Optimization strategy.
Implements the DPO loss from the paper "Direct Preference Optimization".
Uses a reference model to compute KL divergence penalty.
"""
def __init__(
self,
model: nn.Module,
device: str,
beta: float = 0.1,
reduction: str = "mean",
):
self,
model: nn.Module,
device: str,
beta: float = 0.1,
reduction: str = "mean",
):
super().__init__(model, device)
self.ref_model = create_ref_model(model)
self.beta = beta
self.reduction = reduction
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
@ -257,17 +269,19 @@ class DPOStrategy(BaseStrategy):
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
with torch.no_grad():
log_ref = get_logprobs(self.ref_model, contact_ids, contact_mask, self.reduction)
log_pi_chosen = log_pi[:chosen_ids.shape[0]]
log_pi_rejected = log_pi[chosen_ids.shape[0]:]
log_ref_chosen = log_ref[:chosen_ids.shape[0]]
log_ref_rejected = log_ref[chosen_ids.shape[0]:]
log_ref = get_logprobs(
self.ref_model, contact_ids, contact_mask, self.reduction
)
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
pi_log_ratio = log_pi_chosen - log_pi_rejected
ref_log_ratio = log_ref_chosen - log_ref_rejected
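
The concrete loss line is cut off by this hunk; for reference, the standard objective from the DPO paper that these two log-ratios feed into is sketched below (this is the paper's form, not necessarily the exact line in the repository).

```python
# loss = -E[ log sigmoid( beta * (pi_log_ratio - ref_log_ratio) ) ]
loss = -F.logsigmoid(self.beta * (pi_log_ratio - ref_log_ratio)).mean()
```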
@ -280,14 +294,14 @@ class DPOStrategy(BaseStrategy):
@StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy):
"""Group Relative Policy Optimization strategy.
Implements GRPO with clipping and KL penalty.
"""
def __init__(
self,
model: nn.Module,
device: str,
self,
model: nn.Module,
device: str,
clip_eps: float = 0.2,
kl_coef: float = 0.01,
group_size: int = 4,
@ -299,43 +313,47 @@ class GRPOStrategy(BaseStrategy):
self.kl_coef = kl_coef
self.group_size = group_size
self.reduction = reduction
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
prompts = batch["prompts"]
responses = batch["responses"]
masks = batch["masks"]
rewards = batch["rewards"]
batch_size, group_size, response_len = responses.shape
responses_flat = responses.view(-1, response_len)
masks_flat = masks.view(-1, response_len)
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
# Shape: (batch_size * group_size, seq_len + response_len)
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction)
log_probs_policy = get_logprobs(
self.model, full_sequences, full_masks, self.reduction
)
log_probs_policy = log_probs_policy.view(batch_size, group_size)
with torch.no_grad():
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
log_probs_ref = get_logprobs(
self.ref_model, full_sequences, full_masks, self.reduction
)
log_probs_ref = log_probs_ref.view(batch_size, group_size)
# Compute advantages from rewards with normalization
eps = torch.finfo(log_probs_policy.dtype).eps
mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + eps)
# PPO-style clipped surrogate objective
ratio = torch.ones_like(advantages)  # Off-policy: policy_model = old_model, so exp(log-ratio) = exp(0) = 1
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
total_loss = policy_loss + kl_penalty
return total_loss
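
A hedged sketch of driving a strategy through the factory. The batch keys follow SFTStrategy above; model is assumed to be a Transformer with vocab_size >= 1000 on the chosen device.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
strategy = StrategyFactory.create(model, "sft", device=device, label_smoothing=0.1)
batch = {
    "input_ids": torch.randint(0, 1000, (2, 64)),
    "target_ids": torch.randint(0, 1000, (2, 64)),
    "loss_mask": torch.ones(2, 64, dtype=torch.bool),
}
loss = strategy(batch)  # __call__ delegates to compute_loss
loss.backward()
```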


@ -18,52 +18,53 @@ from khaosz.trainer.metric_util import (
ctx_get_grad_norm,
ctx_get_grad_mean,
ctx_get_grad_std,
ctx_get_grad_nan_num
ctx_get_grad_nan_num,
)
from khaosz.data.serialization import Checkpoint
from khaosz.trainer.train_context import TrainContext
class TrainCallback(Protocol):
"""
"""
Callback interface for trainer.
"""
def on_train_begin(self, context: TrainContext):
""" Called at the beginning of training. """
"""Called at the beginning of training."""
def on_train_end(self, context: TrainContext):
""" Called at the end of training. """
"""Called at the end of training."""
def on_epoch_begin(self, context: TrainContext):
""" Called at the beginning of each epoch. """
"""Called at the beginning of each epoch."""
def on_epoch_end(self, context: TrainContext):
""" Called at the end of each epoch. """
"""Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext):
""" Called at the beginning of each step. """
"""Called at the beginning of each step."""
def on_step_end(self, context: TrainContext):
""" Called at the end of each step."""
"""Called at the end of each step."""
def on_batch_begin(self, context: TrainContext):
""" Called at the beginning of each batch. """
"""Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext):
""" Called at the end of each batch. """
"""Called at the end of each batch."""
def on_error(self, context: TrainContext):
""" Called when an error occurs during training. """
"""Called when an error occurs during training."""
class GradientClippingCallback(TrainCallback):
"""
"""
Gradient clipping callback for trainer.
"""
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_begin(self, context: TrainContext):
_ = context
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -73,86 +74,95 @@ class SchedulerCallback(TrainCallback):
"""
Scheduler callback for trainer.
"""
def __init__(self):
pass
def on_train_begin(self, context: TrainContext):
for group in context.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
group["initial_lr"] = group["lr"]
def on_batch_end(self, context: TrainContext):
if context.scheduler:
context.scheduler.step()
class CheckpointCallback(TrainCallback):
"""
"""
Checkpoint callback for trainer.
"""
def __init__(
self,
save_dir: str,
self,
save_dir: str,
interval: int,
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
def on_train_end(self, context: TrainContext):
if context.iteration != self.last_ckpt_iter:
self._save_checkpoint(context)
def on_error(self, context: TrainContext):
self._save_checkpoint(context)
class ProgressBarCallback(TrainCallback):
"""
"""
Progress bar callback for trainer.
"""
def __init__(self, num_epoch: int):
self.num_epoch = num_epoch
self.progress_bar: tqdm = None
@only_on_rank(0)
def on_epoch_begin(self, context: TrainContext):
self.progress_bar = tqdm(
context.dataloader,
desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
dynamic_ncols=True
context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True,
)
@only_on_rank(0)
def on_batch_end(self, context: TrainContext):
self.progress_bar.set_postfix({
"loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
})
self.progress_bar.set_postfix(
{
"loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
}
)
self.progress_bar.update(1)
@only_on_rank(0)
def on_epoch_end(self, context: TrainContext):
_ = context
@ -162,66 +172,65 @@ class ProgressBarCallback(TrainCallback):
class MetricLoggerCallback(TrainCallback):
def __init__(
self,
log_dir:str,
save_interval:int,
log_interval:int=10,
metrics:List[str]=None
self,
log_dir: str,
save_interval: int,
log_interval: int = 10,
metrics: List[str] = None,
):
self.last_log_iter = 0
self.save_interval = save_interval
self.log_interval = log_interval
self.metrics = metrics or ['loss', 'lr']
self.metrics = metrics or ["loss", "lr"]
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
self.log_dir.mkdir(parents=True, exist_ok=True)
self.log_cache = []
self._metric_funcs = {
'loss': ctx_get_loss,
'lr': ctx_get_lr,
'grad_norm': ctx_get_grad_norm,
'grad_std': ctx_get_grad_std,
'grad_max': ctx_get_grad_max,
'grad_min': ctx_get_grad_min,
'grad_mean': ctx_get_grad_mean,
'grad_nan_num': ctx_get_grad_nan_num
"loss": ctx_get_loss,
"lr": ctx_get_lr,
"grad_norm": ctx_get_grad_norm,
"grad_std": ctx_get_grad_std,
"grad_max": ctx_get_grad_max,
"grad_min": ctx_get_grad_min,
"grad_mean": ctx_get_grad_mean,
"grad_nan_num": ctx_get_grad_nan_num,
}
def _get_log_data(self, context: TrainContext):
return {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"epoch": context.epoch,
"iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics}
**{m: self._metric_funcs[m](context) for m in self.metrics},
}
@only_on_rank(0)
def _add_log(self, log_data):
self.log_cache.append(log_data)
@only_on_rank(0)
def _save_log(self, epoch, iter):
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
with open(log_file, 'w') as f:
with open(log_file, "w") as f:
for log in self.log_cache:
f.write(json.dumps(log) + '\n')
f.write(json.dumps(log) + "\n")
def on_batch_end(self, context):
if context.iteration % self.log_interval == 0:
log_data = self._get_log_data(context)
self._add_log(log_data)
if context.iteration - self.last_log_iter >= self.save_interval:
self._save_log(context.epoch, context.iteration)
self.last_log_iter = context.iteration
def on_train_end(self, context):
if context.iteration != self.last_log_iter:
self._save_log(context.epoch, context.iteration)
def on_error(self, context):
self._save_log(context.epoch, context.iteration)
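
As an illustration of the TrainCallback protocol above, a small custom callback (purely hypothetical, not part of the repository) that flags sudden loss spikes:

```python
class LossSpikeCallback(TrainCallback):
    """Warn when the loss jumps sharply between batches (illustration only)."""

    def __init__(self, factor: float = 5.0):
        self.factor = factor
        self.prev_loss = None

    def on_batch_end(self, context: TrainContext):
        if self.prev_loss is not None and context.loss > self.factor * self.prev_loss:
            print(f"loss spike at iter {context.iteration}: {context.loss:.4f}")
        self.prev_loss = context.loss
```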


@ -21,11 +21,11 @@ class TrainContext:
optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
loss: float = field(default=0.0)
world_size: int = field(default=1)
rank: int = field(default=0)
kwargs: dict = field(default_factory=dict)
@ -39,17 +39,17 @@ class TrainContextBuilder:
world_size=get_world_size(),
rank=get_rank(),
)
device = get_current_device()
self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1:
fn = self.config.parallel_wrapper
self._context.model = fn(self._context.model)
self._context.optimizer = self.config.optimizer_fn(self._context.model)
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
checkpoint = Checkpoint(
@ -60,10 +60,10 @@ class TrainContextBuilder:
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.model.load_state_dict(checkpoint.state_dict)
self._context.checkpoint = checkpoint
return self
def with_dataloader(self) -> Self:
# fix: change batch level iteration to sample level offset
config = self.config
@ -72,28 +72,28 @@ class TrainContextBuilder:
data_source=config.dataset,
start_epoch=self._context.epoch,
start_iter=sampler_offset,
seed=config.random_seed
seed=config.random_seed,
)
dataloader = DataLoader(
config.dataset,
batch_size=config.batch_size,
config.dataset,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
train_type=self.config.strategy,
device=get_current_device(),
**self.config.extra_kwargs
**self.config.extra_kwargs,
)
return self
def build(self) -> TrainContext:
return self._context
return self._context


@ -2,12 +2,12 @@ import logging
from typing import Optional, List
from khaosz.config import TrainConfig
from khaosz.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
TrainCallback,
ProgressBarCallback,
CheckpointCallback,
MetricLoggerCallback,
GradientClippingCallback,
SchedulerCallback
SchedulerCallback,
)
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.data.serialization import Checkpoint
@ -18,37 +18,39 @@ logger = logging.getLogger(__name__)
class Trainer:
def __init__(
self,
train_config: TrainConfig,
callbacks: Optional[List[TrainCallback]] = None
self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
):
self.train_config = train_config
default_callbacks = self._get_default_callbacks()
self.callbacks = default_callbacks + callbacks if callbacks else default_callbacks
self.callbacks = (
default_callbacks + callbacks if callbacks else default_callbacks
)
def _get_default_callbacks(self) -> List[TrainCallback]:
train_config = self.train_config
return [
ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build())
return (
TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
)
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
if method:
method(context)
def train(self, checkpoint: Optional[Checkpoint] = None):
config = self.train_config
spawn_parallel_fn(
@ -59,45 +61,45 @@ class Trainer:
master_port=config.master_port,
device_type=config.device_type,
device_ids=config.device_ids,
checkpoint=checkpoint
checkpoint=checkpoint,
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
self._call_callbacks('on_train_begin', context)
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
# 1.epoch
for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch
self._call_callbacks('on_epoch_begin', context)
self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader:
# 3. batch
self._call_callbacks('on_batch_begin', context)
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
context.iteration += 1
# to make the loss normalized by accumulation steps
stand_loss = loss / self.train_config.accumulation_steps
stand_loss.backward()
self._call_callbacks('on_batch_end', context)
self._call_callbacks("on_batch_end", context)
if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step
self._call_callbacks('on_step_begin', context)
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks('on_step_end', context)
self._call_callbacks("on_step_end", context)
self._call_callbacks("on_epoch_end", context)
self._call_callbacks('on_epoch_end', context)
except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True)
self._call_callbacks('on_error', context)
self._call_callbacks("on_error", context)
raise
finally:
self._call_callbacks('on_train_end', context)
self._call_callbacks("on_train_end", context)
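
Wiring it together, a hedged sketch: TrainConfig construction is project-specific and omitted, and the extra callback is the hypothetical one sketched earlier.

```python
trainer = Trainer(train_config, callbacks=[LossSpikeCallback(factor=5.0)])
trainer.train()
# to resume: trainer.train(checkpoint=Checkpoint.load(resume_dir))
```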


@ -26,7 +26,7 @@ classifiers = [
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
[project.optional-dependencies]
dev = ["pytest==9.0.2"]
dev = ["pytest==9.0.2", "ruff"]
[tool.setuptools.packages.find]
where = ["."]
@ -35,4 +35,13 @@ where = ["."]
extra-index-url = "https://download.pytorch.org/whl/cu126"
[tool.setuptools.dynamic]
version = { attr = "khaosz.__version__" }
version = { attr = "khaosz.__version__" }
[tool.ruff]
target-version = "py312"
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"


@ -17,14 +17,14 @@ class RandomDataset(Dataset):
self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length
self.vocab_size = vocab_size
def __len__(self):
return self.length
def __getitem__(self, idx):
return {
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,))
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
}
@ -33,10 +33,10 @@ class MultiTurnDataset(Dataset):
self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length
self.vocab_size = vocab_size
def __len__(self):
return self.length
def __getitem__(self, idx):
input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
@ -54,18 +54,18 @@ class EarlyStoppingDataset(Dataset):
self.length = length
self.stop_after = stop_after
self.count = 0
def __len__(self):
return self.length
def __getitem__(self, idx):
self.count += 1
if self.count == self.stop_after:
raise RuntimeError("Simulated early stopping")
return {
"input_ids": torch.randint(0, 1000, (64,)),
"target_ids": torch.randint(0, 1000, (64,))
"target_ids": torch.randint(0, 1000, (64,)),
}
@ -74,10 +74,10 @@ def base_test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4]
dim = int(np.random.choice(n_dim_choices))
n_heads = int(np.random.choice(n_head_choices))
n_kv_heads = n_heads // 2
@ -91,16 +91,16 @@ def base_test_env(request: pytest.FixtureRequest):
"dim_ffn": dim_ffn,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5
"norm_eps": 1e-5,
}
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config, f)
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config).to(device=device)
tokenizer = BpeTokenizer()
yield {
"device": device,
"test_dir": str(test_dir),
@ -109,20 +109,23 @@ def base_test_env(request: pytest.FixtureRequest):
"model": model,
"tokenizer": tokenizer,
}
shutil.rmtree(test_dir)
@pytest.fixture
def random_dataset():
dataset = RandomDataset()
yield dataset
@pytest.fixture
def multi_turn_dataset():
dataset = MultiTurnDataset()
yield dataset
@pytest.fixture
def early_stopping_dataset():
dataset = EarlyStoppingDataset()
yield dataset
yield dataset

View File

@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.serialization import Checkpoint
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
def test_single_process():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
@ -14,34 +15,31 @@ def test_single_process():
for epoch in range(3):
for iteration in range(10):
x = torch.randn(32, 10)
y = torch.randn(32, 5)
loss = model(x).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
checkpoint = Checkpoint(
state_dict=model.state_dict(),
epoch=3,
iteration=30
)
checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
loaded_checkpoint = Checkpoint.load(tmpdir)
assert loaded_checkpoint.epoch == 3
assert loaded_checkpoint.iteration == 30
def simple_training():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=10)
for epoch in range(2):
for iteration in range(5):
x = torch.randn(16, 10)
@ -57,28 +55,23 @@ def simple_training():
epoch=2,
iteration=10,
)
rank = get_rank()
if rank == 0:
shared_dir = tempfile.mkdtemp()
checkpoint.save(shared_dir)
else:
shared_dir = None
if dist.is_initialized():
dir_list = [shared_dir]
dist.broadcast_object_list(dir_list, src=0)
shared_dir = dir_list[0]
loaded = Checkpoint.load(shared_dir)
assert loaded.epoch == 2
def test_multi_process():
spawn_parallel_fn(
simple_training,
world_size=2,
backend="gloo"
)
spawn_parallel_fn(simple_training, world_size=2, backend="gloo")

View File

@ -5,30 +5,32 @@ from khaosz.data.serialization import save_h5
from khaosz.data.dataset import *
def test_dataset_loader_random_paths(base_test_env):
"""Test dataset loader with multiple random paths"""
test_dir = base_test_env["test_dir"]
# Create multiple h5 dataset files with random data
num_files = np.random.randint(2, 5)
for i in range(num_files):
seq_length = np.random.randint(200, 400)
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
"sequence": [
torch.randint(0, 1000, (seq_length,), dtype=torch.int64)
for _ in range(10)
],
}
save_h5(test_dir, f"data_{i}", dummy_data)
# Test loading with multiple paths
loaded_dataset = DatasetLoader.load(
train_type="seq",
load_path=test_dir,
window_size=64,
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert loaded_dataset is not None
assert len(loaded_dataset) > 0
# Test that we can get items without errors
for i in range(len(loaded_dataset)):
item = loaded_dataset[i]
@ -41,30 +43,30 @@ def test_dataset_loader_random_paths(base_test_env):
def test_dpo_strategy_with_random_data(base_test_env):
"""Test DPO strategy with randomized preference data"""
test_dir = base_test_env["test_dir"]
# Create DPO-style data with memory mapping format
seq_length = np.random.randint(100, 200)
dummy_data = {
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)],
}
save_h5(test_dir, "dpo_data", dummy_data)
# Load DPO dataset
dpo_dataset = DatasetLoader.load(
train_type="dpo",
load_path=test_dir,
window_size=64,
train_type="dpo",
load_path=test_dir,
window_size=64,
)
assert dpo_dataset is not None
assert hasattr(dpo_dataset, 'fetcher')
assert hasattr(dpo_dataset, "fetcher")
assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors
for i in range(min(3, len(dpo_dataset))):
item = dpo_dataset[i]
@ -79,28 +81,28 @@ def test_dpo_strategy_with_random_data(base_test_env):
def test_sft_dataset_with_random_data(base_test_env):
"""Test SFT dataset with random data"""
test_dir = base_test_env["test_dir"]
# Create SFT-style data with memory mapping format
seq_length = np.random.randint(100, 200)
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
}
save_h5(test_dir, "sft_data", dummy_data)
# Load SFT dataset
sft_dataset = DatasetLoader.load(
train_type="sft",
load_path=test_dir,
window_size=64,
train_type="sft",
load_path=test_dir,
window_size=64,
)
assert sft_dataset is not None
assert hasattr(sft_dataset, 'fetcher')
assert hasattr(sft_dataset, "fetcher")
assert len(sft_dataset) > 0
# Test that we can get SFT items without errors
for i in range(min(3, len(sft_dataset))):
item = sft_dataset[i]
@ -114,33 +116,30 @@ def test_sft_dataset_with_random_data(base_test_env):
def test_dataset_with_custom_stride(base_test_env):
"""Test dataset with custom stride parameter"""
test_dir = base_test_env["test_dir"]
# Create test data
seq_length = 200
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
save_h5(test_dir,"stride_test_data", dummy_data)
save_h5(test_dir, "stride_test_data", dummy_data)
# Test with custom stride
custom_stride = 32
dataset = DatasetLoader.load(
train_type="seq",
load_path=test_dir,
window_size=64,
stride=custom_stride
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
)
assert dataset is not None
assert len(dataset) > 0
# With stride 32 and window 64 on 200 length data, we should get more samples
# than with default stride (which equals window size)
default_stride_dataset = DatasetLoader.load(
train_type="seq",
load_path=test_dir,
window_size=64,
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert len(dataset) > len(default_stride_dataset)
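# Rough arithmetic behind the assertion above, assuming a simple sliding-window count of
# floor((seq_len - window) / stride) + 1 (the exact DatasetLoader count may differ):
seq_len, window = 200, 64
for stride in (32, 64):
    n_windows = (seq_len - window) // stride + 1
    print(f"stride={stride}: {n_windows} windows")  # stride=32 -> 5, stride=64 -> 3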

View File

@ -1,30 +1,32 @@
from khaosz.trainer import *
from khaosz.data import *
def test_random_sampler_consistency(random_dataset):
"""Test RandomSampler produces consistent results with same seed"""
dataset = random_dataset
# Create two samplers with same seed
sampler1 = ResumableDistributedSampler(dataset, seed=42)
sampler2 = ResumableDistributedSampler(dataset, seed=42)
indices1 = list(iter(sampler1))
indices2 = list(iter(sampler2))
assert indices1 == indices2
def test_random_sampler_different_seeds(random_dataset):
"""Test RandomSampler produces different results with different seeds"""
dataset = random_dataset
# Create two samplers with different seeds
sampler1 = ResumableDistributedSampler(dataset, seed=42)
sampler2 = ResumableDistributedSampler(dataset, seed=123)
indices1 = list(iter(sampler1))
indices2 = list(iter(sampler2))
# Very high probability they should be different
assert indices1 != indices2
@ -33,20 +35,20 @@ def test_sampler_across_epochs(random_dataset):
"""Test sampler behavior across multiple epochs"""
dataset = random_dataset
n = len(dataset)
sampler = ResumableDistributedSampler(dataset, seed=42)
# Get indices for first epoch
epoch1_indices = list(iter(sampler))
assert len(epoch1_indices) == n
# Get indices for second epoch
epoch2_indices = list(iter(sampler))
assert len(epoch2_indices) == n
# Check that epochs have different order (should be random)
assert epoch1_indices != epoch2_indices
# Check that all indices are present in each epoch
assert set(epoch1_indices) == set(range(n))
assert set(epoch2_indices) == set(range(n))
assert set(epoch2_indices) == set(range(n))

View File

@ -12,6 +12,7 @@ from khaosz.data import *
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
from tokenizers import pre_tokenizers
@pytest.fixture
def test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__
@ -19,7 +20,7 @@ def test_env(request: pytest.FixtureRequest):
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")
config = {
"vocab_size": 1000,
"dim": 128,
@ -28,20 +29,20 @@ def test_env(request: pytest.FixtureRequest):
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5
"norm_eps": 1e-5,
}
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config, f)
tokenizer = BpeTokenizer()
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path)
yield {
"test_dir": test_dir,
"model": model,
@ -51,47 +52,55 @@ def test_env(request: pytest.FixtureRequest):
shutil.rmtree(test_dir)
def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
model_param = ModelParameter(
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
ModelParameter.save(model_param, save_dir)
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
assert os.path.exists(os.path.join(save_dir, "config.json"))
# transformer
def test_transformer(test_env):
model = test_env["model"]
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].max_len))
input_ids = torch.randint(
0,
test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].max_len),
)
output_logits = model(input_ids)["logits"]
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
target_shape = (
4,
test_env["transformer_config"].max_len,
test_env["transformer_config"].vocab_size,
)
assert output_logits.shape == target_shape
# generator
def test_embedding_encoder_core(test_env):
parameter = ModelParameter(
test_env["model"],
test_env["tokenizer"],
test_env["transformer_config"]
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
encoder = EmbeddingEncoderCore(parameter)
single_emb = encoder.encode("测试文本")
assert isinstance(single_emb, torch.Tensor)
assert single_emb.shape[-1] == test_env["transformer_config"].dim
batch_emb = encoder.encode(["测试1", "测试2"])
assert isinstance(batch_emb, list)
assert len(batch_emb) == 2
def test_generator_core(test_env):
parameter = ModelParameter(
test_env["model"],
test_env["tokenizer"],
test_env["transformer_config"]
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
generator = GeneratorCore(parameter)
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
@ -102,8 +111,8 @@ def test_generator_core(test_env):
top_p=0.95,
attn_mask=None,
kv_caches=None,
start_pos=0
start_pos=0,
)
assert next_token_id.shape == (4, 1)
assert cache_increase == 10

View File

@ -13,7 +13,7 @@ def transformer_test_env():
"""创建Transformer测试专用环境"""
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
config_path = os.path.join(test_dir, "config.json")
config = {
"vocab_size": 1000,
"dim": 128,
@ -22,18 +22,14 @@ def transformer_test_env():
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5
"norm_eps": 1e-5,
}
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config, f)
yield {
"test_dir": test_dir,
"config_path": config_path,
"config": config
}
yield {"test_dir": test_dir, "config_path": config_path, "config": config}
if os.path.exists(test_dir):
try:
for file in os.listdir(test_dir):
@ -46,74 +42,75 @@ def transformer_test_env():
def test_tie_weight_init(transformer_test_env):
config_path = transformer_test_env["config_path"]
config_data = transformer_test_env["config"].copy()
# case 1: tie weight
config_data["tie_weight"] = True
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
original_weight = model.embed_tokens.weight.clone()
model.embed_tokens.weight.data[0, 0] = 100.0
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert not torch.equal(model.lm_head.weight, original_weight)
# case 2: not tie weight
config_data["tie_weight"] = False
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
original_weight = model.embed_tokens.weight.clone()
model.embed_tokens.weight.data[0, 0] = 100.0
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert not torch.equal(model.lm_head.weight, original_weight)
def test_model_save_load_with_tie_weight(transformer_test_env):
test_dir = transformer_test_env["test_dir"]
model_path = os.path.join(test_dir, "model.safetensors")
config_data = transformer_test_env["config"].copy()
# case 1: tie weight
config_data["tie_weight"] = True
config_path = os.path.join(test_dir, "config.json")
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
assert "lm_head.weight" not in model.state_dict()
# case 2: not tie weight (load from a tie-weight state dict)
config_data["tie_weight"] = False
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(config_data, f)
loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
@ -121,4 +118,3 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
assert "lm_head.weight" in model.state_dict()

View File

@ -1,16 +1,14 @@
import torch
import torch.distributed as dist
from khaosz.parallel import (
get_rank,
only_on_rank,
spawn_parallel_fn
)
from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
@only_on_rank(0)
def _test_only_on_rank_helper():
return True
def only_on_rank():
result = _test_only_on_rank_helper()
if get_rank() == 0:
@ -18,22 +16,17 @@ def only_on_rank():
else:
assert result is None
def all_reduce():
x = torch.tensor([get_rank()], dtype=torch.int)
dist.all_reduce(x, op=dist.ReduceOp.SUM)
expected_sum = sum(range(dist.get_world_size()))
assert x.item() == expected_sum
def test_spawn_only_on_rank():
spawn_parallel_fn(
only_on_rank,
world_size=2,
backend="gloo"
)
spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo")
def test_spawn_all_reduce():
spawn_parallel_fn(
all_reduce,
world_size=2,
backend="gloo"
)
spawn_parallel_fn(all_reduce, world_size=2, backend="gloo")

View File

@ -3,57 +3,48 @@ import torch
from khaosz.config import *
from khaosz.trainer import *
def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated"""
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig(
model=base_test_env["model"],
strategy='seq',
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"],
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=2,
checkpoint_interval=3,
ckpt_interval=3,
accumulation_steps=1,
max_grad_norm=1.0,
random_seed=42,
device_type=base_test_env["device"]
device_type=base_test_env["device"],
)
# Create custom callbacks to track calls
callback_calls = []
class TrackingCallback(TrainCallback):
def on_train_begin(self, context):
callback_calls.append('on_train_begin')
def on_batch_end(self, context):
callback_calls.append('on_batch_end')
def on_epoch_end(self, context):
callback_calls.append('on_epoch_end')
callback_calls.append("on_train_begin")
def on_batch_end(self, context):
callback_calls.append("on_batch_end")
def on_epoch_end(self, context):
callback_calls.append("on_epoch_end")
trainer = Trainer(train_config, callbacks=[TrackingCallback()])
trainer = Trainer(
train_config,
callbacks=[TrackingCallback()]
)
trainer.train()
# Verify callbacks were called
assert 'on_train_begin' in callback_calls
assert 'on_batch_end' in callback_calls
assert 'on_epoch_end' in callback_calls
assert "on_train_begin" in callback_calls
assert "on_batch_end" in callback_calls
assert "on_epoch_end" in callback_calls

View File

@ -5,31 +5,32 @@ from khaosz.config import *
from khaosz.trainer import *
from khaosz.data.serialization import Checkpoint
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior"""
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig(
strategy="seq",
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
model=base_test_env["model"],
dataset=early_stopping_dataset,
checkpoint_dir=base_test_env["test_dir"],
ckpt_dir=base_test_env["test_dir"],
n_epoch=2,
batch_size=2,
checkpoint_interval=1,
ckpt_interval=1,
accumulation_steps=2,
random_seed=np.random.randint(1e4),
device_type=base_test_env["device"]
device_type=base_test_env["device"],
)
trainer = Trainer(train_config)
# Should handle early stopping gracefully
checkpoint = None
try:
@ -37,11 +38,11 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
except Exception:
# Handle any exceptions
pass
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
checkpoint = Checkpoint.load(load_dir)
trainer.train(checkpoint)
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10
assert checkpoint.iteration == 10

View File

@ -9,39 +9,41 @@ from khaosz.data.dataset import *
def test_schedule_factory_random_configs():
"""Test scheduler factory with random configurations"""
# Create a simple model and optimizer for testing
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# Test multiple random configurations
for _ in range(5): # Test 5 random configurations
schedule_configs = [
CosineScheduleConfig(
warmup_steps=np.random.randint(50, 200),
total_steps=np.random.randint(1000, 5000),
min_rate=np.random.uniform(0.01, 0.1)
min_rate=np.random.uniform(0.01, 0.1),
),
SGDRScheduleConfig(
warmup_steps=np.random.randint(50, 200),
cycle_length=np.random.randint(500, 2000),
t_mult=np.random.randint(1, 3),
min_rate=np.random.uniform(0.01, 0.1)
)
min_rate=np.random.uniform(0.01, 0.1),
),
]
for config in schedule_configs:
# Validate configuration
config.validate()
# Create scheduler using factory
scheduler = SchedulerFactory.load(optimizer, config)
# Verify scheduler type
if isinstance(config, CosineScheduleConfig):
assert isinstance(scheduler, CosineScheduler)
assert scheduler.warmup_steps == config.warmup_steps
assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
assert (
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
)
assert scheduler.min_rate == config.min_rate
elif isinstance(config, SGDRScheduleConfig):
assert isinstance(scheduler, SGDRScheduler)
@ -49,17 +51,17 @@ def test_schedule_factory_random_configs():
assert scheduler.cycle_length == config.cycle_length
assert scheduler.t_mult == config.t_mult
assert scheduler.min_rate == config.min_rate
# Test scheduler state dict functionality
state_dict = scheduler.state_dict()
assert 'warmup_steps' in state_dict
assert 'min_rate' in state_dict
assert "warmup_steps" in state_dict
assert "min_rate" in state_dict
# Test scheduler step functionality
initial_lr = scheduler.get_last_lr()
scheduler.step()
new_lr = scheduler.get_last_lr()
# Learning rate should change after step, or if it's the first step,
# the epoch counter should increment
assert initial_lr != new_lr or scheduler.last_epoch > -1
@ -67,10 +69,10 @@ def test_schedule_factory_random_configs():
def test_schedule_factory_edge_cases():
"""Test scheduler factory with edge cases and boundary conditions"""
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# Test edge cases for CosineScheduleConfig
edge_cases = [
# Minimal warmup and steps
@ -80,12 +82,12 @@ def test_schedule_factory_edge_cases():
# Zero min_rate (edge case)
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
]
for config in edge_cases:
config.validate()
scheduler = SchedulerFactory.load(optimizer, config)
assert scheduler is not None
# Test multiple steps
for _ in range(10):
scheduler.step()
@ -93,7 +95,7 @@ def test_schedule_factory_edge_cases():
def test_schedule_factory_invalid_configs():
"""Test scheduler factory with invalid configurations"""
# Test invalid configurations that should raise errors
invalid_configs = [
# Negative warmup steps
@ -104,7 +106,7 @@ def test_schedule_factory_invalid_configs():
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
]
for kwargs in invalid_configs:
with pytest.raises(ValueError):
config = CosineScheduleConfig(**kwargs)
@ -113,24 +115,24 @@ def test_schedule_factory_invalid_configs():
def test_schedule_factory_state_persistence():
"""Test scheduler state persistence (save/load)"""
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
scheduler = SchedulerFactory.load(optimizer, config)
# Take a few steps
for _ in range(5):
scheduler.step()
# Save state
state_dict = scheduler.state_dict()
# Create new scheduler and load state
new_scheduler = SchedulerFactory.load(optimizer, config)
new_scheduler.load_state_dict(state_dict)
# Verify states match
assert scheduler.last_epoch == new_scheduler.last_epoch
assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
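# A hedged sketch of the warmup + cosine decay curve exercised by these tests (the actual
# CosineScheduler internals are not shown in this diff; this is the conventional formula):
import math

def lr_rate(step: int, warmup_steps: int, total_steps: int, min_rate: float) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to the max rate
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_rate + 0.5 * (1.0 - min_rate) * (1.0 + math.cos(math.pi * progress))

# With warmup_steps=100, total_steps=1000, min_rate=0.1:
# step 0 -> 0.0, step 100 -> 1.0, step 550 -> 0.55, step 1000 -> 0.1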

View File

@ -6,100 +6,94 @@ from khaosz.config import *
from khaosz.trainer import *
from khaosz.data.dataset import *
def test_different_batch_sizes(base_test_env, random_dataset):
"""Test training with different batch sizes"""
batch_sizes = [1, 2, 4, 8]
for batch_size in batch_sizes:
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig(
strategy="seq",
model=base_test_env["model"],
dataset=random_dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"],
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=batch_size,
checkpoint_interval=5,
ckpt_interval=5,
accumulation_steps=1,
max_grad_norm=1.0,
random_seed=np.random.randint(1000),
device_type=base_test_env["device"]
device_type=base_test_env["device"],
)
assert train_config.batch_size == batch_size
def test_gradient_accumulation(base_test_env, random_dataset):
"""Test training with different gradient accumulation steps"""
accumulation_steps_list = [1, 2, 4]
for accumulation_steps in accumulation_steps_list:
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig(
strategy="seq",
model=base_test_env["model"],
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
dataset=random_dataset,
checkpoint_dir=base_test_env["test_dir"],
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=2,
checkpoint_interval=10,
ckpt_interval=10,
accumulation_steps=accumulation_steps,
max_grad_norm=1.0,
random_seed=42,
device_type=base_test_env["device"]
device_type=base_test_env["device"],
)
trainer = Trainer(train_config)
trainer.train()
assert train_config.accumulation_steps == accumulation_steps
def test_memory_efficient_training(base_test_env, random_dataset):
"""Test training with memory-efficient configurations"""
# Test with smaller batch sizes and gradient checkpointing
small_batch_configs = [
{"batch_size": 1, "accumulation_steps": 8},
{"batch_size": 2, "accumulation_steps": 4},
{"batch_size": 4, "accumulation_steps": 2}
{"batch_size": 4, "accumulation_steps": 2},
]
for config in small_batch_configs:
schedule_config = CosineScheduleConfig(
warmup_steps=10,
total_steps=20
)
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig(
strategy="seq",
model=base_test_env["model"],
dataset=random_dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"],
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_size=config["batch_size"],
checkpoint_interval=5,
ckpt_interval=5,
accumulation_steps=config["accumulation_steps"],
max_grad_norm=1.0,
random_seed=42,
device_type=base_test_env["device"]
device_type=base_test_env["device"],
)
assert train_config.accumulation_steps == config["accumulation_steps"]
assert train_config.accumulation_steps == config["accumulation_steps"]

View File

@ -17,41 +17,47 @@ class GenerationBenchmark:
self,
config: ModelConfig,
device: str = "cuda",
dtype: torch.dtype = torch.float16
dtype: torch.dtype = torch.float16,
):
self.config = config
self.device = device
self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval()
def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存"""
config = self.config
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
shape = (
batch_size,
config.max_len,
config.n_layers,
config.n_kv_heads,
config.dim // config.n_heads,
)
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint(
low=0,
high=self.config.vocab_size,
size=(batch_size, prompt_length),
device=self.device,
dtype=torch.long
dtype=torch.long,
)
gen_ids = torch.randint(
low=0,
high=self.config.vocab_size,
size=(batch_size, total_length - prompt_length),
device=self.device,
dtype=torch.long
dtype=torch.long,
)
return prompt_ids, gen_ids
@torch.inference_mode()
def run_prefill_benchmark(
self,
@ -59,32 +65,38 @@ class GenerationBenchmark:
prompt_length: int = 512,
num_trials: int = 10,
) -> BenchmarkResult:
for _ in range(3):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
_ = self.model(prompt_ids)
torch.cuda.synchronize()
total_time = 0.0
total_tokens = batch_size * prompt_length * num_trials
for trial in range(num_trials):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
_ = self.model(prompt_ids)
end_event.record()
torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)")
print(
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult(
total_tokens=total_tokens,
total_time=total_time,
@ -95,9 +107,9 @@ class GenerationBenchmark:
"prompt_length": prompt_length,
"dtype": self.dtype,
"device": self.device,
}
},
)
@torch.inference_mode()
def run_decoding_benchmark(
self,
@ -106,39 +118,43 @@ class GenerationBenchmark:
gen_length: int = 128,
num_trials: int = 5,
) -> BenchmarkResult:
total_time = 0.0
total_tokens = batch_size * gen_length * num_trials
for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
prompt_ids, gen_ids = self._prepare_inputs(
batch_size, prompt_length, prompt_length + gen_length
)
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
current_pos = prompt_length
for i in range(gen_length):
input_token = gen_ids[:, i:i+1]
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
input_token = gen_ids[:, i : i + 1]
_ = self.model(
input_token, persistent_key_values=kv_cache, start_pos=current_pos
)
current_pos += 1
end_event.record()
torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)")
print(
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult(
total_tokens=total_tokens,
total_time=total_time,
@ -150,24 +166,28 @@ class GenerationBenchmark:
"gen_length": gen_length,
"dtype": self.dtype,
"device": self.device,
}
},
)
def print_benchmark_result(result: BenchmarkResult):
"""打印基准测试结果"""
benchmark_type = result.metadata["benchmark_type"]
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
if benchmark_type == "prefill":
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
print(
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding":
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80)
@ -183,16 +203,19 @@ if __name__ == "__main__":
n_layers=24,
norm_eps=1e-5,
)
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5)
prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, prompt_length=512, num_trials=5
)
print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5)
gen_result = benchmark.run_decoding_benchmark(
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
)
print_benchmark_result(gen_result)
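# Back-of-the-envelope memory use for the KV cache shape initialized above,
# (batch, max_len, n_layers, n_kv_heads, dim // n_heads), with K and V in fp16.
# The numbers below are illustrative placeholders, not the benchmark's real config.
batch, max_len, n_layers, n_kv_heads, head_dim = 4, 2048, 24, 8, 64
kv_bytes = 2 * batch * max_len * n_layers * n_kv_heads * head_dim * 2  # 2 tensors x 2 bytes each
print(f"{kv_bytes / 1024**3:.2f} GiB")  # ~0.38 GiB for these placeholder values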

View File

@ -21,10 +21,10 @@ def processor(
with disable_random_init():
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
param.to(device="cuda", dtype=torch.bfloat16)
generator = BatchGenerator(param)
with open(input_json_file, "r", encoding='utf-8') as f:
with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_data]
@ -41,26 +41,62 @@ def processor(
responses = generator.generate(request)
with open(output_json_file, "w", encoding='utf-8') as f:
with open(output_json_file, "w", encoding="utf-8") as f:
for query, response in zip(queries, responses):
output_item = {question_key: query, response_key: response}
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.")
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory."
)
parser.add_argument(
"--input_json_file",
type=str,
required=True,
help="Path to the input JSONL file.",
)
parser.add_argument(
"--output_json_file",
type=str,
required=True,
help="Path to the output JSONL file.",
)
parser.add_argument(
"--question_key",
type=str,
default="question",
help="Key for the question in the input JSON.",
)
parser.add_argument(
"--response_key",
type=str,
default="response",
help="Key for the response in the output JSON.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.60,
help="Temperature for generating responses.",
)
parser.add_argument(
"--top_k", type=int, default=30, help="Top-k value for generating responses."
)
parser.add_argument(
"--top_p",
type=float,
default=0.95,
help="Top-p value for generating responses.",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for generating responses."
)
args = parser.parse_args()
with torch.inference_mode():
processor(**vars(args))
processor(**vars(args))

View File

@ -11,89 +11,99 @@ from khaosz.inference.core import disable_random_init
def compute_perplexity(
model: nn.Module,
input_ids: Tensor,
input_mask: Tensor,
) -> Tensor:
model: nn.Module,
input_ids: Tensor,
input_mask: Tensor,
) -> Tensor:
"""
Compute the perplexity of a batch of input sequences,
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
Compute the perplexity of a batch of input sequences,
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
"""
output = model(input_ids, input_mask)
logits = output["logits"]
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
loss = F.cross_entropy(
shifted_logits.flatten(0, 1),
shifted_input_ids.flatten(0, 1),
reduction='none'
shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
)
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
loss = loss * shifted_mask
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
perplexity = torch.exp(sentence_loss) # [batch_size]
perplexity = torch.exp(sentence_loss) # [batch_size]
return perplexity
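# Quick sanity check of the formula above, independent of any model: with a uniform
# distribution over V tokens every log-probability is -log(V), so PPL = exp(log V) = V.
import math

V = 1000
per_token_nll = -math.log(1.0 / V)
assert abs(math.exp(per_token_nll) - V) < 1e-6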
def process_file(
model_dir: str,
input_file: str,
output_file: str,
batch_size: int,
text_key: str
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
):
with disable_random_init():
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
param.to(device="cuda", dtype=torch.bfloat16)
model = param.model
tokenizer = param.tokenizer
with open(input_file, "r", encoding='utf-8') as f:
with open(input_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
texts = [item[text_key] for item in input_data]
encoded_texts = [tokenizer.encode(text) for text in texts]
output_data = []
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
batch_encoded = encoded_texts[i:i + batch_size]
batch_texts = texts[i:i + batch_size]
for i in tqdm(
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
):
batch_encoded = encoded_texts[i : i + batch_size]
batch_texts = texts[i : i + batch_size]
max_len = max(len(seq) for seq in batch_encoded)
padded_ids = []
masks = []
for seq in batch_encoded:
pad_len = max_len - len(seq)
padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [False] * pad_len + [True] * len(seq)
padded_ids.append(padded_seq)
masks.append(mask)
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
perplexity = compute_perplexity(model, input_ids, input_mask)
for text, ppl in zip(batch_texts, perplexity):
output_data.append({text_key: text, "ppl": float(ppl.item())})
with open(output_file, "w", encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
for item in output_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
f.write(json.dumps(item, ensure_ascii=False) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory."
)
parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file."
)
parser.add_argument(
"--output_file", type=str, required=True, help="Path to the output file."
)
parser.add_argument(
"--batch_size", type=int, default=4, help="Batch size for evaluation."
)
parser.add_argument(
"--text_key",
type=str,
default="text",
help="Key for the text field in the input data.",
)
args = parser.parse_args()
with torch.inference_mode():

View File

@ -15,40 +15,130 @@ from khaosz.parallel import get_rank
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument(
"--train_type",
type=str,
required=True,
choices=["seq", "sft", "dpo"],
help="Train type.",
)
parser.add_argument(
"--data_root_path",
type=str,
required=True,
help="Path to the root directory of the dataset.",
)
parser.add_argument(
"--param_path",
type=str,
required=True,
help="Path to the model parameters or resume checkpoint.",
)
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for training."
)
parser.add_argument(
"--accumulation_steps",
type=int,
default=1,
help="Number of iterations between each optimizer step.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=1000,
help="Number of iters between warnings.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=1.0,
help="Max gradient norm for clipping.",
)
parser.add_argument(
"--adamw_beta1",
type=float,
default=0.9,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.95,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_weight_decay",
type=float,
default=0.01,
help="Weight decay for AdamW optimizer.",
)
parser.add_argument(
"--random_seed", type=int, default=3407, help="Random seed for reproducibility."
)
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of workers for data loading."
)
parser.add_argument(
"--no_pin_memory",
action="store_false",
dest="pin_memory",
help="Disable pin memory",
)
parser.add_argument(
"--window_size",
type=int,
default=None,
help="the max length of the input sequence.",
)
parser.add_argument(
"--stride", type=int, default=None, help="the step size of the input sequence."
)
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument(
"--label_smoothing",
type=float,
default=0.1,
help="cross_entropy function label smoothing parameter",
)
parser.add_argument(
"--ckpt_interval",
type=int,
default=5000,
help="Number of iters between checkpoints.",
)
parser.add_argument(
"--ckpt_dir",
type=str,
default="checkpoint",
help="Directory to save checkpoints.",
)
parser.add_argument(
"--start_epoch", type=int, default=0, help="Start epoch for training."
)
parser.add_argument(
"--start_batch", type=int, default=0, help="Start batch for training."
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
args = parser.parse_args()
return args
def ddp_wrap(model: nn.Module):
local_rank = get_rank()
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
@ -56,16 +146,21 @@ def ddp_wrap(model: nn.Module):
model,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=False
find_unused_parameters=False,
)
return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), **kwargs)
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
def create_scheduler(
optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.load(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict()
@ -81,8 +176,8 @@ def train(
start_batch: int,
accumulation_steps: int,
warmup_steps: int,
checkpoint_interval: int,
checkpoint_dir: str,
ckpt_interval: int,
ckpt_dir: str,
dpo_beta: float,
adamw_beta1: float,
adamw_beta2: float,
@ -99,48 +194,50 @@ def train(
):
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
parameter = ModelParameter.load(param_path)
if window_size is None:
window_size = parameter.config.max_len
model = parameter.model
strategy_kwargs = {
"dpo_beta": dpo_beta,
"label_smoothing": label_smoothing
}
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
dataset = DatasetLoader.load(
train_type=train_type,
load_path=data_root_path,
window_size=window_size,
stride=stride
)
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
stride=stride,
)
optimizer_fn = partial(create_optimizer,
**{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
scheduler_fn = partial(create_scheduler,
**{"schedule_config": schedule_config})
schedule_config = CosineScheduleConfig(
warmup_steps=warmup_steps,
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
)
optimizer_fn = partial(
create_optimizer,
**{
"lr": max_lr,
"betas": (adamw_beta1, adamw_beta2),
"weight_decay": adamw_weight_decay,
},
)
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
train_config = TrainConfig(
model=model,
strategy=train_type,
dataset=dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
checkpoint_dir=checkpoint_dir,
ckpt_dir=ckpt_dir,
n_epoch=n_epoch,
batch_size=batch_size,
start_epoch=start_epoch,
start_batch=start_batch,
checkpoint_interval=checkpoint_interval,
ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
@ -152,11 +249,11 @@ def train(
device_type=device_type,
extra_kwargs=strategy_kwargs,
)
trainer = Trainer(train_config)
trainer.train()
if __name__ == "__main__":
args = parse_args()
train(**vars(args))
train(**vars(args))
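# Illustrative arithmetic for the scheduler's total_steps computed above,
# total_steps = len(dataset) * n_epoch // (batch_size * nprocs), with placeholder values:
# a 100_000-sample dataset, 5 epochs, batch_size 8 on 2 GPUs gives
print(100_000 * 5 // (8 * 2))  # -> 31250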