style: use ruff to optimize code style

ViperEkura 2026-03-30 23:32:28 +08:00
parent 345fd2f091
commit 426af2d75f
52 changed files with 1838 additions and 1495 deletions

.github/workflows/lint.yml (vendored, new file, 27 changed lines)

@ -0,0 +1,27 @@
name: Lint
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[dev]
- name: Check formatting with ruff
run: |
ruff format --check .
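
The workflow gates merges on formatting only. A minimal sketch for reproducing the same check locally before pushing, assuming `ruff` is installed (e.g. via `pip install .[dev]` as in the workflow):

```python
# Local mirror of the CI step "ruff format --check ." (assumes ruff is on PATH).
import subprocess
import sys

result = subprocess.run(["ruff", "format", "--check", "."])
sys.exit(result.returncode)  # non-zero when any file would be reformatted
```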


@ -1,17 +0,0 @@
name: Spell Check
on: [push, pull_request]
permissions:
contents: read
jobs:
spellcheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Check spelling in specific files
uses: codespell-project/actions-codespell@v2
with:
check_filenames: true
only_warn: false
path: "**/*.{md, py}"

.github/workflows/tests.yml (vendored, new file, 31 changed lines)

@ -0,0 +1,31 @@
name: Tests
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[dev]
- name: Run tests with pytest
run: |
python -m pytest tests/ -v

.gitignore (vendored, 3 changed lines)

@ -8,5 +8,8 @@
!*.py !*.py
!*.md !*.md
!*.png !*.png
!LICENSE !LICENSE
!pyproject.toml !pyproject.toml
!.github/workflows/lint.yml
!.github/workflows/tests.yml


@ -54,8 +54,8 @@ python train.py \
--n_epoch=5 \ --n_epoch=5 \
--batch_size=8 \ --batch_size=8 \
--max_lr=2e-4 \ --max_lr=2e-4 \
--checkpoint_interval=10000 \ --ckpt_interval=10000 \
--checkpoint_dir=checkpoints --ckpt_dir=checkpoints
``` ```
**Parameter Explanation:** **Parameter Explanation:**
@ -67,8 +67,8 @@ python train.py \
- `--accumulation_steps`: Number of batches per training step - `--accumulation_steps`: Number of batches per training step
- `--warmup_steps`: Warmup steps - `--warmup_steps`: Warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay) - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- `--checkpoint_interval`: Checkpoint saving interval - `--ckpt_interval`: Checkpoint saving interval
- `--checkpoint_dir`: Checkpoint saving directory - `--ckpt_dir`: Checkpoint saving directory
- `--resume_dir`: Resume training from specified path - `--resume_dir`: Resume training from specified path
@ -191,8 +191,8 @@ python train.py \
--n_epoch=5 \ --n_epoch=5 \
--batch_size=8 \ --batch_size=8 \
--max_lr=2e-4 \ --max_lr=2e-4 \
--checkpoint_interval=10000 \ --ckpt_interval=10000 \
--checkpoint_dir=checkpoints --ckpt_dir=checkpoints
``` ```
**Parameter Explanation:** **Parameter Explanation:**
@ -204,8 +204,8 @@ python train.py \
- `--accumulation_steps`: Number of batches per training step - `--accumulation_steps`: Number of batches per training step
- `--warmup_steps`: Warmup steps - `--warmup_steps`: Warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay) - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- `--checkpoint_interval`: Checkpoint saving interval - `--ckpt_interval`: Checkpoint saving interval
- `--checkpoint_dir`: Checkpoint saving directory - `--ckpt_dir`: Checkpoint saving directory
- `--resume_dir`: Resume training from specified path - `--resume_dir`: Resume training from specified path


@ -2,13 +2,12 @@ import os
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
if __name__ == "__main__": if __name__ == "__main__":
snapshot_download( snapshot_download(
repo_id="ViperEk/KHAOSZ", repo_id="ViperEk/KHAOSZ",
local_dir=os.path.join(PROJECT_ROOT, "params"), local_dir=os.path.join(PROJECT_ROOT, "params"),
force_download=True force_download=True,
) )


@ -5,8 +5,8 @@ from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import LoopGenerator, GenerationRequest from khaosz.inference.generator import LoopGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
def generate_text(): def generate_text():
@ -14,7 +14,7 @@ def generate_text():
model_dir = os.path.join(PROJECT_ROOT, "params") model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
query = input(">> ") query = input(">> ")
request = GenerationRequest( request = GenerationRequest(
@ -31,5 +31,6 @@ def generate_text():
print(response) print(response)
if __name__ == "__main__": if __name__ == "__main__":
generate_text() generate_text()


@ -4,17 +4,23 @@ from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import BatchGenerator, GenerationRequest from khaosz.inference.generator import BatchGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
def batch_generate(): def batch_generate():
with disable_random_init(): with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params") model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
generator = BatchGenerator(param) generator = BatchGenerator(param)
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"] inputs = [
"你好",
"请问什么是人工智能",
"今天天气如何",
"我感到焦虑, 请问我应该怎么办",
"请问什么是显卡",
]
request = GenerationRequest( request = GenerationRequest(
query=inputs, query=inputs,
@ -30,5 +36,6 @@ def batch_generate():
for q, r in zip(inputs, responses): for q, r in zip(inputs, responses):
print((q, r)) print((q, r))
if __name__ == "__main__": if __name__ == "__main__":
batch_generate() batch_generate()


@ -5,8 +5,8 @@ from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import StreamGenerator, GenerationRequest from khaosz.inference.generator import StreamGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname( PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.dirname(os.path.abspath(__file__)))
def chat(): def chat():
@ -14,7 +14,7 @@ def chat():
model_dir = os.path.join(PROJECT_ROOT, "params") model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
generator = StreamGenerator(param) generator = StreamGenerator(param)
history = [] history = []


@ -6,41 +6,30 @@ from khaosz.config import (
TrainConfig, TrainConfig,
) )
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
from khaosz.data import ( from khaosz.data import DatasetLoader, BpeTokenizer
DatasetLoader,
BpeTokenizer
)
from khaosz.inference.generator import ( from khaosz.inference.generator import (
GenerationRequest, GenerationRequest,
LoopGenerator, LoopGenerator,
StreamGenerator, StreamGenerator,
BatchGenerator, BatchGenerator,
EmbeddingEncoder, EmbeddingEncoder,
GeneratorFactory GeneratorFactory,
)
from khaosz.trainer import (
Trainer,
StrategyFactory,
SchedulerFactory
) )
from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
__all__ = [ __all__ = [
"Transformer", "Transformer",
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",
"DatasetLoader", "DatasetLoader",
"BpeTokenizer", "BpeTokenizer",
"GenerationRequest", "GenerationRequest",
"LoopGenerator", "LoopGenerator",
"StreamGenerator", "StreamGenerator",
"BatchGenerator", "BatchGenerator",
"EmbeddingEncoder", "EmbeddingEncoder",
"GeneratorFactory", "GeneratorFactory",
"Trainer", "Trainer",
"StrategyFactory", "StrategyFactory",
"SchedulerFactory" "SchedulerFactory",
] ]


@ -4,7 +4,7 @@ from khaosz.config.schedule_config import (
ScheduleConfig, ScheduleConfig,
CosineScheduleConfig, CosineScheduleConfig,
SGDRScheduleConfig, SGDRScheduleConfig,
ScheduleConfigFactory ScheduleConfigFactory,
) )
from khaosz.config.train_config import TrainConfig from khaosz.config.train_config import TrainConfig
@ -13,11 +13,9 @@ __all__ = [
# Base I/O # Base I/O
"BaseModelIO", "BaseModelIO",
"ModelParameter", "ModelParameter",
# Model configuration # Model configuration
"ModelConfig", "ModelConfig",
"TrainConfig", "TrainConfig",
# Schedule configuration # Schedule configuration
"ScheduleConfig", "ScheduleConfig",
"CosineScheduleConfig", "CosineScheduleConfig",


@ -25,10 +25,9 @@ class ModelConfig:
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self: def load(self, config_path: str) -> Self:
config = {} config = {}
with open(config_path, 'r') as f: with open(config_path, "r") as f:
config.update(json.load(f)) config.update(json.load(f))
for key, value in config.items(): for key, value in config.items():
@ -39,5 +38,5 @@ class ModelConfig:
def save(self, config_path: str): def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None} config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4) json.dump(config_dict, f, indent=4)


@ -9,21 +9,20 @@ from khaosz.data.tokenizer import BpeTokenizer
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
@dataclass @dataclass
class BaseModelIO: class BaseModelIO:
"""Base class for model I/O operations.""" """Base class for model I/O operations."""
model: Optional[nn.Module] = field( model: Optional[nn.Module] = field(
default=None, default=None, metadata={"help": "Transformer model."}
metadata={"help": "Transformer model."}
) )
tokenizer: BpeTokenizer = field( tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer, default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
metadata={"help": "Tokenizer for the model."}
) )
config: ModelConfig = field( config: ModelConfig = field(
default_factory=ModelConfig, default_factory=ModelConfig,
metadata={"help": "Transformer model configuration."} metadata={"help": "Transformer model configuration."},
) )
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]: def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
@ -32,7 +31,7 @@ class BaseModelIO:
return { return {
"model": dir_path / "model.safetensors", "model": dir_path / "model.safetensors",
"config": dir_path / "config.json", "config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json" "tokenizer": dir_path / "tokenizer.json",
} }
def save_components(self, save_dir: Union[str, Path]): def save_components(self, save_dir: Union[str, Path]):
@ -80,4 +79,3 @@ class ModelParameter(BaseModelIO):
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter": def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
instance = cls() instance = cls()
return instance.load_components(load_dir) return instance.load_components(load_dir)


@ -14,16 +14,14 @@ class ScheduleConfig(ABC):
default="cosine", default="cosine",
metadata={ metadata={
"help": "Type of learning rate schedule.", "help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"] "choices": ["cosine", "sgdr"],
} },
) )
warmup_steps: int = field( warmup_steps: int = field(
default=1000, default=1000, metadata={"help": "Number of warmup steps."}
metadata={"help": "Number of warmup steps."}
) )
min_rate: float = field( min_rate: float = field(
default=0.05, default=0.05, metadata={"help": "Minimum learning rate multiplier."}
metadata={"help": "Minimum learning rate multiplier."}
) )
@abstractmethod @abstractmethod
@ -34,7 +32,9 @@ class ScheduleConfig(ABC):
def validate(self) -> None: def validate(self) -> None:
"""Validate configuration parameters.""" """Validate configuration parameters."""
if self.warmup_steps < 0: if self.warmup_steps < 0:
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}") raise ValueError(
f"warmup_steps must be non-negative, got {self.warmup_steps}"
)
if not 0 <= self.min_rate <= 1: if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}") raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@ -44,8 +44,7 @@ class CosineScheduleConfig(ScheduleConfig):
"""Cosine annealing learning rate schedule configuration.""" """Cosine annealing learning rate schedule configuration."""
total_steps: int = field( total_steps: int = field(
default=None, default=None, metadata={"help": "Total training steps for cosine schedule."}
metadata={"help": "Total training steps for cosine schedule."}
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -60,13 +59,15 @@ class CosineScheduleConfig(ScheduleConfig):
"schedule_type": self.schedule_type, "schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps, "warmup_steps": self.warmup_steps,
"lr_decay_steps": self.total_steps - self.warmup_steps, "lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate "min_rate": self.min_rate,
} }
def validate(self) -> None: def validate(self) -> None:
super().validate() super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps: if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})") raise ValueError(
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
)
@dataclass @dataclass
@ -74,12 +75,10 @@ class SGDRScheduleConfig(ScheduleConfig):
"""Stochastic Gradient Descent with Warm Restarts schedule configuration.""" """Stochastic Gradient Descent with Warm Restarts schedule configuration."""
cycle_length: int = field( cycle_length: int = field(
default=1000, default=1000, metadata={"help": "Length of the first cycle in steps."}
metadata={"help": "Length of the first cycle in steps."}
) )
t_mult: int = field( t_mult: int = field(
default=2, default=2, metadata={"help": "Multiplier for cycle length growth."}
metadata={"help": "Multiplier for cycle length growth."}
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -92,7 +91,7 @@ class SGDRScheduleConfig(ScheduleConfig):
"warmup_steps": self.warmup_steps, "warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length, "cycle_length": self.cycle_length,
"min_rate": self.min_rate, "min_rate": self.min_rate,
"t_mult": self.t_mult "t_mult": self.t_mult,
} }
def validate(self) -> None: def validate(self) -> None:
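
`CosineScheduleConfig` only packages hyperparameters (`warmup_steps`, `lr_decay_steps = total_steps - warmup_steps`, `min_rate`); the scheduler itself is built elsewhere by `SchedulerFactory`. As a rough sketch of the multiplier a warmup + cosine-decay schedule with a `min_rate` floor typically computes (an illustration of the shape, not the repository's exact curve):

```python
import math

def warmup_cosine_multiplier(
    step: int, warmup_steps: int, lr_decay_steps: int, min_rate: float
) -> float:
    # Linear warmup from 0 to the peak rate, then cosine decay down to min_rate.
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    progress = min((step - warmup_steps) / max(lr_decay_steps, 1), 1.0)
    return min_rate + (1.0 - min_rate) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Usable with torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: ...).
print(warmup_cosine_multiplier(500, 1000, 9000, 0.05))    # 0.5   (mid-warmup)
print(warmup_cosine_multiplier(10000, 1000, 9000, 0.05))  # 0.05  (fully decayed)
```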


@ -10,127 +10,92 @@ from typing import Callable, List, Optional
@dataclass @dataclass
class TrainConfig: class TrainConfig:
# basic setting # basic setting
model: nn.Module = field( model: nn.Module = field(default=None, metadata={"help": "Model for training."})
default=None, strategy: str = field(default=None, metadata={"help": "Training strategy."})
metadata={"help": "Model for training."} dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
)
strategy: str = field(
default=None,
metadata={"help": "Training strategy."}
)
dataset: Dataset = field(
default=None,
metadata={"help": "Dataset for training."}
)
optimizer_fn: Callable[[nn.Module], Optimizer] = field( optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, default=None, metadata={"help": "Optimizer factory for training."}
metadata={"help": "Optimizer factory for training."}
) )
scheduler_fn: Callable[[Optimizer], LRScheduler] = field( scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, default=None, metadata={"help": "Scheduler factory for training."}
metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(
default=1,
metadata={"help": "Number of epochs for training."}
)
batch_size: int = field(
default=4,
metadata={"help": "Batch size for training."}
) )
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
accumulation_steps: int = field( accumulation_steps: int = field(
default=1, default=1, metadata={"help": "Number of iterations between steps."}
metadata={"help": "Number of iterations between steps."}
) )
max_grad_norm: float = field( max_grad_norm: float = field(
default=1.0, default=1.0, metadata={"help": "Maximum gradient norm."}
metadata={"help": "Maximum gradient norm."}
) )
# checkpoint setting # checkpoint setting
start_epoch: int = field( start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
default=0,
metadata={"help": "Start epoch for training."}
)
start_batch: int = field( start_batch: int = field(
default=0, default=0, metadata={"help": "Start batch iteration for training."}
metadata={"help": "Start batch iteration for training."}
) )
checkpoint_dir: str = field( ckpt_dir: str = field(
default="./checkpoint", default="./checkpoint", metadata={"help": "Checkpoint directory."}
metadata={"help": "Checkpoint directory."}
) )
checkpoint_interval: int = field( ckpt_interval: int = field(
default=5000, default=5000, metadata={"help": "Number of iterations between checkpoints."}
metadata={"help": "Number of iterations between checkpoints."}
) )
# dataloader setting # dataloader setting
random_seed: int = field( random_seed: int = field(default=3407, metadata={"help": "Random seed."})
default=3407,
metadata={"help": "Random seed."}
)
num_workers: int = field( num_workers: int = field(
default=0, default=0, metadata={"help": "Number of workers for dataloader."}
metadata={"help": "Number of workers for dataloader."}
) )
prefetch_factor: Optional[int] = field( prefetch_factor: Optional[int] = field(
default=None, default=None, metadata={"help": "Prefetch factor for dataloader."}
metadata={"help": "Prefetch factor for dataloader."}
) )
pin_memory: bool = field( pin_memory: bool = field(
default=False, default=False, metadata={"help": "Pin memory for dataloader."}
metadata={"help": "Pin memory for dataloader."}
) )
# distributed training # distributed training
nprocs: int = field( nprocs: int = field(
default=1, default=1, metadata={"help": "Number of processes for distributed training."}
metadata={"help": "Number of processes for distributed training."}
) )
backend: str = field( backend: str = field(
default="nccl", default="nccl", metadata={"help": "Distributed training backend."}
metadata={"help": "Distributed training backend."}
) )
master_addr: str = field( master_addr: str = field(
default="localhost", default="localhost",
metadata={"help": "Master address for distributed training."} metadata={"help": "Master address for distributed training."},
) )
master_port: str = field( master_port: str = field(
default="29500", default="29500", metadata={"help": "Master port for distributed training."}
metadata={"help": "Master port for distributed training."}
) )
parallel_wrapper: Optional[Callable] = field( parallel_wrapper: Optional[Callable] = field(
default=None, default=None, metadata={"help": "Parallel function for training."}
metadata={"help": "Parallel function for training."}
) )
state_dict_fn: Optional[Callable] = field( state_dict_fn: Optional[Callable] = field(
default=None, default=None, metadata={"help": "Parallel function for state dict saving."}
metadata={"help": "Parallel function for state dict saving."}
) )
# others # others
device_ids: Optional[List[int]] = field( device_ids: Optional[List[int]] = field(
default=None, default=None, metadata={"help": "Device ids for distributed training."}
metadata={"help": "Device ids for distributed training."}
) )
device_type: str = field( device_type: str = field(
default="cuda", default="cuda", metadata={"help": "Device type for distributed training."}
metadata={"help": "Device type for distributed training."}
) )
extra_kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, default_factory=dict, metadata={"help": "Other arguments."}
metadata={"help": "Other arguments."}
) )
def __post_init__(self): def __post_init__(self):
self.validate() self.validate()
def validate(self): def validate(self):
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"] required_fields = [
"model",
"strategy",
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields: for field_name in required_fields:
if getattr(self, field_name) is None: if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.") raise ValueError(f"{field_name} is required.")
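
A hedged wiring sketch for `TrainConfig`: the required fields and the `ckpt_*` names come from the dataclass above, while the strategy key and the stand-in model/dataset are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import TensorDataset
from khaosz.config import TrainConfig

config = TrainConfig(
    model=nn.Linear(8, 8),                      # stand-in nn.Module
    strategy="sft",                             # assumed strategy key
    dataset=TensorDataset(torch.zeros(16, 8)),  # stand-in Dataset
    optimizer_fn=lambda m: torch.optim.AdamW(m.parameters(), lr=2e-4),
    scheduler_fn=lambda opt: LambdaLR(opt, lambda step: 1.0),
    n_epoch=5,
    batch_size=8,
    ckpt_interval=10_000,
    ckpt_dir="checkpoints",
)
```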


@ -6,7 +6,7 @@ from khaosz.data.dataset import (
GRPODataset, GRPODataset,
MultiSegmentFetcher, MultiSegmentFetcher,
DatasetLoader, DatasetLoader,
DatasetFactory DatasetFactory,
) )
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
@ -15,21 +15,17 @@ from khaosz.data.sampler import ResumableDistributedSampler
__all__ = [ __all__ = [
# Base classes # Base classes
"BaseDataset", "BaseDataset",
# Dataset implementations # Dataset implementations
"SEQDataset", "SEQDataset",
"SFTDataset", "SFTDataset",
"DPODataset", "DPODataset",
"GRPODataset", "GRPODataset",
# Fetchers # Fetchers
"MultiSegmentFetcher", "MultiSegmentFetcher",
# Factory (DatasetLoader is alias for backward compatibility) # Factory (DatasetLoader is alias for backward compatibility)
"DatasetLoader", "DatasetLoader",
"DatasetFactory", "DatasetFactory",
# Tokenizer and sampler # Tokenizer and sampler
"BpeTokenizer", "BpeTokenizer",
"ResumableDistributedSampler" "ResumableDistributedSampler",
] ]


@ -41,7 +41,9 @@ class BaseSegmentFetcher:
Returns: Returns:
Concatenated tensor of data in the specified range Concatenated tensor of data in the specified range
""" """
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length): if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds") raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx: if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long) return torch.tensor([], dtype=torch.long)
@ -71,8 +73,7 @@ class MultiSegmentFetcher:
def __init__(self, muti_segments: Dict): def __init__(self, muti_segments: Dict):
self.muti_keys = list(muti_segments.keys()) self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = { self.muti_fetchers = {
key: BaseSegmentFetcher(segments) key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
for key, segments in muti_segments.items()
} }
def __len__(self) -> int: def __len__(self) -> int:
@ -80,7 +81,9 @@ class MultiSegmentFetcher:
len_list = [len(seg) for seg in self.muti_fetchers.values()] len_list = [len(seg) for seg in self.muti_fetchers.values()]
return min(len_list) return min(len_list)
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict: def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys. """Fetch data for specific keys.
Args: Args:
@ -189,11 +192,13 @@ class DatasetFactory:
Returns: Returns:
Decorator function that registers the dataset class Decorator function that registers the dataset class
""" """
def decorator(dataset_cls: type) -> type: def decorator(dataset_cls: type) -> type:
if not issubclass(dataset_cls, BaseDataset): if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
cls.DATASET_MAP[name] = dataset_cls cls.DATASET_MAP[name] = dataset_cls
return dataset_cls return dataset_cls
return decorator return decorator
@classmethod @classmethod
@ -223,7 +228,13 @@ class DatasetFactory:
return dataset_cls(window_size, stride) return dataset_cls(window_size, stride)
@classmethod @classmethod
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset: def load(
cls,
train_type: str,
load_path: str,
window_size: int,
stride: Optional[int] = None,
) -> BaseDataset:
"""Create and load a dataset in one step. """Create and load a dataset in one step.
Args: Args:
@ -286,8 +297,12 @@ class SFTDataset(BaseDataset):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool) dtype=torch.long
)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask} return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
@ -307,10 +322,19 @@ class DPODataset(BaseDataset):
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long) chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long) rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool) chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool) dtype=torch.bool
)
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
dtype=torch.bool
)
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask} return {
"chosen": chosen,
"rejected": rejected,
"chosen_mask": chosen_mask,
"rejected_mask": rejected_mask,
}
@DatasetFactory.register("grpo") @DatasetFactory.register("grpo")
@ -331,7 +355,12 @@ class GRPODataset(BaseDataset):
masks = self._fetch_data(begin_idx, end_idx, "masks") masks = self._fetch_data(begin_idx, end_idx, "masks")
rewards = self._fetch_data(begin_idx, end_idx, "rewards") rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards} return {
"prompts": prompts,
"responses": responses,
"masks": masks,
"rewards": rewards,
}
# Backward compatibility alias # Backward compatibility alias
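
A hedged sketch of extending the registry: the decorator, the `BaseDataset` subclass requirement, and the `(window_size, stride)` construction contract are taken from `DatasetFactory` above; the dataset body itself is illustrative.

```python
import torch
from khaosz.data import BaseDataset, DatasetFactory

@DatasetFactory.register("toy")
class ToyDataset(BaseDataset):
    def __getitem__(self, index):
        # get_index / _fetch_data are used the same way by SFTDataset above.
        begin_idx, end_idx = self.get_index(index)
        x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
        return {"input_ids": x}

# Later: DatasetFactory.load("toy", "path/to/data", window_size=1024)
```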


@ -9,12 +9,12 @@ class ResumableDistributedSampler(Sampler[int]):
def __init__( def __init__(
self, self,
data_source: Dataset, data_source: Dataset,
start_epoch: int=0, start_epoch: int = 0,
start_iter: int=0, start_iter: int = 0,
seed: int=42, seed: int = 42,
drop_last: bool=False, drop_last: bool = False,
shuffle: bool=True, shuffle: bool = True,
process_group: Optional[dist.ProcessGroup]=None, process_group: Optional[dist.ProcessGroup] = None,
): ):
self.epoch = start_epoch self.epoch = start_epoch
self.iter = start_iter self.iter = start_iter
@ -40,7 +40,7 @@ class ResumableDistributedSampler(Sampler[int]):
self.drop_last = drop_last self.drop_last = drop_last
self.shuffle = shuffle self.shuffle = shuffle
offset = 0 if drop_last else self.num_replicas - 1 offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas
@ -58,10 +58,10 @@ class ResumableDistributedSampler(Sampler[int]):
padding_size = self.total_size - len(indices) padding_size = self.total_size - len(indices)
indices += indices[:padding_size] indices += indices[:padding_size]
local_indices = indices[self.rank:self.total_size:self.num_replicas] local_indices = indices[self.rank : self.total_size : self.num_replicas]
self.iter = self.iter % self.num_samples_per_replica self.iter = self.iter % self.num_samples_per_replica
self._indices = local_indices[self.iter:] self._indices = local_indices[self.iter :]
def __iter__(self): def __iter__(self):
if self._indices is None: if self._indices is None:
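
The resume logic above shards a padded index list across replicas with a stride equal to the world size, then skips the first `iter` items of the local shard. A torch-free toy run of the same arithmetic (shuffling omitted; the real sampler reshuffles per epoch from the seed):

```python
num_samples, num_replicas, drop_last = 10, 4, False
offset = 0 if drop_last else num_replicas - 1
per_replica = (num_samples + offset) // num_replicas   # (10 + 3) // 4 = 3
total_size = per_replica * num_replicas                # 12

indices = list(range(num_samples))
indices += indices[: total_size - len(indices)]        # pad: [0..9, 0, 1]

start_iter = 1                                         # resuming mid-epoch
for rank in range(num_replicas):
    local = indices[rank:total_size:num_replicas]      # strided shard per rank
    print(rank, local[start_iter % per_replica:])      # skip already-seen items
# 0 [4, 8] | 1 [5, 9] | 2 [6, 0] | 3 [7, 1]
```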


@ -10,15 +10,17 @@ from torch import Tensor
from typing import Any, Dict, List from typing import Any, Dict, List
from khaosz.parallel.setup import get_rank from khaosz.parallel.setup import get_rank
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5") full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, 'w') as f: with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items(): for key, tensors in tensor_group.items():
grp = f.create_group(key) grp = f.create_group(key)
for idx, tensor in enumerate(tensors): for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy() arr = tensor.cpu().numpy()
grp.create_dataset(f'data_{idx}', data=arr) grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {} tensor_group: Dict[str, List[Tensor]] = {}
@ -27,7 +29,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files: for h5_file in h5_files:
with h5py.File(h5_file, 'r') as f: with h5py.File(h5_file, "r") as f:
for key in f.keys(): for key in f.keys():
grp = f[key] grp = f[key]
dsets = [] dsets = []
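
The layout these helpers write is one HDF5 group per key, with datasets `data_0, data_1, ...` per tensor. Since the helpers' module path is not visible in this diff, here is a standalone round trip of the same layout using h5py directly:

```python
import h5py
import torch

group = {"sequence": [torch.arange(4), torch.arange(8)]}

with h5py.File("shard_000.h5", "w") as f:            # mirrors save_h5
    for key, tensors in group.items():
        grp = f.create_group(key)
        for idx, t in enumerate(tensors):
            grp.create_dataset(f"data_{idx}", data=t.cpu().numpy())

with h5py.File("shard_000.h5", "r") as f:            # mirrors load_h5
    restored = [torch.from_numpy(f["sequence"][f"data_{i}"][()]) for i in range(2)]

assert torch.equal(restored[1], group["sequence"][1])
```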


@ -12,15 +12,16 @@ class BpeTokenizer:
model = BPE() model = BPE()
self._tokenizer = Tokenizer(model) self._tokenizer = Tokenizer(model)
self._tokenizer.normalizer = normalizers.Sequence([ self._tokenizer.normalizer = normalizers.Sequence(
normalizers.NFC(), [normalizers.NFC(), normalizers.Strip()]
normalizers.Strip() )
])
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
pre_tokenizers.UnicodeScripts(), [
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True) pre_tokenizers.UnicodeScripts(),
]) pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
]
)
self._tokenizer.decoder = decoders.ByteLevel() self._tokenizer.decoder = decoders.ByteLevel()
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
@ -28,10 +29,21 @@ class BpeTokenizer:
if path is not None: if path is not None:
self._tokenizer = Tokenizer.from_file(path) self._tokenizer = Tokenizer.from_file(path)
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple: def _prepare_trainer(
self,
vocab_size: int,
min_freq: int,
reserved_token_size: int,
max_token_length=18,
) -> tuple:
assert reserved_token_size > len(self._special_tokens) assert reserved_token_size > len(self._special_tokens)
reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))] reserved_tokens = [
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens)) f"<|reserve{i:02d}|>"
for i in range(reserved_token_size - len(self._special_tokens))
]
detail_vocab_size = vocab_size - (
len(reserved_tokens) + len(self._special_tokens)
)
alphabet = pre_tokenizers.ByteLevel.alphabet() alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self._control_tokens) min_size = len(alphabet) + len(self._control_tokens)
@ -53,16 +65,18 @@ class BpeTokenizer:
trainer, _, reserved_tokens = self._prepare_trainer( trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size, vocab_size=vocab_size,
min_freq=min_freq, min_freq=min_freq,
reserved_token_size=reserved_token_size reserved_token_size=reserved_token_size,
) )
self._tokenizer.train(files=files, trainer=trainer) self._tokenizer.train(files=files, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens) self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100): def train_from_iterator(
self, iterator, vocab_size, min_freq, reserved_token_size=100
):
trainer, _, reserved_tokens = self._prepare_trainer( trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size, vocab_size=vocab_size,
min_freq=min_freq, min_freq=min_freq,
reserved_token_size=reserved_token_size reserved_token_size=reserved_token_size,
) )
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer) self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens) self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
@ -73,15 +87,26 @@ class BpeTokenizer:
def load(self, path): def load(self, path):
self._tokenizer = Tokenizer.from_file(path) self._tokenizer = Tokenizer.from_file(path)
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List: def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
add_special_tokens: bool = False,
) -> List:
if isinstance(tokens, str): if isinstance(tokens, str):
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens) encoded: Encoding = self._tokenizer.encode(
tokens, add_special_tokens=add_special_tokens
)
return encoded.ids if out_ids else encoded.tokens return encoded.ids if out_ids else encoded.tokens
elif isinstance(tokens, list): elif isinstance(tokens, list):
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens) encoded_list: List[Encoding] = self._tokenizer.encode_batch(
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list] tokens, add_special_tokens=add_special_tokens
)
return [
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
]
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str: def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int: def __len__(self) -> int:
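
A usage sketch based only on the signatures visible here (`train_from_iterator`, `encode`, `decode`); the corpus and sizes are toy values:

```python
from khaosz.data import BpeTokenizer

tok = BpeTokenizer()
corpus = ["hello world", "你好, 世界"] * 200          # any iterator of strings
tok.train_from_iterator(corpus, vocab_size=1024, min_freq=2)

ids = tok.encode("hello world")                      # out_ids=True -> token ids
print(tok.decode(ids))                               # round-trips back to text
```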


@ -11,7 +11,7 @@ from khaosz.inference.generator import (
StreamGenerator, StreamGenerator,
BatchGenerator, BatchGenerator,
EmbeddingEncoder, EmbeddingEncoder,
GeneratorFactory GeneratorFactory,
) )
__all__ = [ __all__ = [
@ -19,11 +19,10 @@ __all__ = [
"GeneratorCore", "GeneratorCore",
"EmbeddingEncoderCore", "EmbeddingEncoderCore",
"KVCacheManager", "KVCacheManager",
"GenerationRequest", "GenerationRequest",
"LoopGenerator", "LoopGenerator",
"StreamGenerator", "StreamGenerator",
"BatchGenerator", "BatchGenerator",
"EmbeddingEncoder", "EmbeddingEncoder",
"GeneratorFactory" "GeneratorFactory",
] ]


@ -12,7 +12,7 @@ def apply_sampling_strategies(
temperature: float, temperature: float,
top_k: int, top_k: int,
top_p: float, top_p: float,
filter_value: float = -float("inf") filter_value: float = -float("inf"),
) -> Tensor: ) -> Tensor:
""" """
Apply sampling strategies to the logits tensor. Apply sampling strategies to the logits tensor.
@ -47,9 +47,7 @@ def apply_sampling_strategies(
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_( indices_to_remove.scatter_(
dim=1, dim=1, index=sorted_indices, src=sorted_indices_to_remove
index=sorted_indices,
src=sorted_indices_to_remove
) )
logits[indices_to_remove] = filter_value logits[indices_to_remove] = filter_value
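
The repository's `apply_sampling_strategies` implements the standard temperature / top-k / top-p pipeline via `scatter_`. A self-contained sketch of the same pattern, for illustration only:

```python
import torch

def sample_next(logits: torch.Tensor, temperature=1.0, top_k=50, top_p=0.9) -> torch.Tensor:
    logits = logits / max(temperature, 1e-5)
    if top_k > 0:                                    # keep only the k largest logits
        k = min(top_k, logits.size(-1))
        kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    remove = probs.cumsum(dim=-1) - probs > top_p    # mass before token exceeds p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), 1)

print(sample_next(torch.randn(2, 32)))               # one sampled id per row
```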
@ -60,10 +58,15 @@ def apply_sampling_strategies(
@contextmanager @contextmanager
def disable_random_init(): def disable_random_init():
init_functions = [ init_functions = [
'xavier_normal_', 'xavier_uniform_', "xavier_normal_",
'kaiming_normal_', 'kaiming_uniform_', "xavier_uniform_",
'zeros_', 'ones_', 'constant_', "kaiming_normal_",
'normal_', 'uniform_' "kaiming_uniform_",
"zeros_",
"ones_",
"constant_",
"normal_",
"uniform_",
] ]
original_funcs = {} original_funcs = {}
for name in init_functions: for name in init_functions:
@ -91,8 +94,8 @@ class GeneratorCore:
top_p: float, top_p: float,
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0 start_pos: int = 0,
)-> Tuple[Tensor, int]: ) -> Tuple[Tensor, int]:
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos) outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
@ -115,13 +118,20 @@ class GeneratorCore:
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0, start_pos: int = 0,
callback: Optional[Callable[..., Any]] = None callback: Optional[Callable[..., Any]] = None,
) -> List[int]: ) -> List[int]:
cur_cache_pos = start_pos cur_cache_pos = start_pos
for _ in range(len(ids), self.config.max_len): for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos) input_ids,
temperature,
top_k,
top_p,
attn_mask,
kv_caches,
cur_cache_pos,
)
input_ids = next_token_id input_ids = next_token_id
ids.append(next_token_id.item()) ids.append(next_token_id.item())
@ -157,14 +167,17 @@ class EmbeddingEncoderCore:
for i, seq in enumerate(batch_ids): for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len: if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)] fragments = [
seq[j : j + max_model_len]
for j in range(0, len(seq), max_model_len)
]
all_fragments.extend(fragments) all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments)) fragment_origin_idx.extend([i] * len(fragments))
else: else:
all_fragments.append(seq) all_fragments.append(seq)
fragment_origin_idx.append(i) fragment_origin_idx.append(i)
#if empty fragments # if empty fragments
if not all_fragments or not ids: if not all_fragments or not ids:
return [] if with_batch else torch.tensor([]) return [] if with_batch else torch.tensor([])
@ -190,11 +203,17 @@ class EmbeddingEncoderCore:
sentence_embs: List[Tensor] = [] sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)): for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i] indices = [
idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
]
if indices: if indices:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size] sum_frags = torch.sum(
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1] fragment_embs[indices, :, :], dim=1
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size] ) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(
1
) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten()) sentence_embs.append(emb.flatten())
if with_batch: if with_batch:
@ -213,7 +232,7 @@ class KVCacheManager:
config: ModelConfig, config: ModelConfig,
batch_size: int, batch_size: int,
device: torch.device = "cuda", device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16 dtype: torch.dtype = torch.bfloat16,
): ):
self.batch_size = batch_size self.batch_size = batch_size
self.device = device self.device = device
@ -221,7 +240,7 @@ class KVCacheManager:
self.num_layers = config.n_layers self.num_layers = config.n_layers
self.max_len = config.max_len self.max_len = config.max_len
self.num_heads = config.n_kv_heads self.num_heads = config.n_kv_heads
self.head_dim = config.dim //config.n_heads self.head_dim = config.dim // config.n_heads
self._kv_cache: Tuple[Tensor, Tensor] = None self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None self._seq_mask: Tensor = None
@ -229,15 +248,31 @@ class KVCacheManager:
def _initialize(self): def _initialize(self):
k_cache = torch.empty( k_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), (
device=self.device, dtype=self.dtype self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
) )
v_cache = torch.empty( v_cache = torch.empty(
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), (
device=self.device, dtype=self.dtype self.batch_size,
self.max_len,
self.num_layers,
self.num_heads,
self.head_dim,
),
device=self.device,
dtype=self.dtype,
) )
self._kv_cache = (k_cache, v_cache) self._kv_cache = (k_cache, v_cache)
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool) self._seq_mask = torch.ones(
(self.batch_size, self.max_len), device=self.device, dtype=torch.bool
)
def update(self, active_mask: Tensor): def update(self, active_mask: Tensor):
k_cache, v_cache = self._kv_cache k_cache, v_cache = self._kv_cache
@ -253,8 +288,8 @@ class KVCacheManager:
def set_seq_mask(self, input_ids: Tensor, pad_id: int): def set_seq_mask(self, input_ids: Tensor, pad_id: int):
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape
bool_mask = (input_ids != pad_id) bool_mask = input_ids != pad_id
self._seq_mask[: batch_size, : seq_len] = bool_mask self._seq_mask[:batch_size, :seq_len] = bool_mask
def get_kvcache(self) -> Tuple[Tensor, Tensor]: def get_kvcache(self) -> Tuple[Tensor, Tensor]:
return self._kv_cache return self._kv_cache
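
The cache layout is `(batch, max_len, n_layers, n_kv_heads, head_dim)`, written and read per layer exactly as `GQA.forward` does further down. A toy walk-through of the indexing:

```python
import torch

bsz, max_len, n_layers, n_kv_heads, head_dim = 2, 16, 4, 2, 8
k_cache = torch.empty(bsz, max_len, n_layers, n_kv_heads, head_dim)

start_pos, seq_len, layer_id = 3, 5, 1
k = torch.randn(bsz, seq_len, n_kv_heads, head_dim)
k_cache[:bsz, start_pos:start_pos + seq_len, layer_id] = k   # write this step

k_all = k_cache[:bsz, :start_pos + seq_len, layer_id]        # read full prefix
print(k_all.shape)                                           # (2, 8, 2, 8)
```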


@ -8,10 +8,11 @@ from khaosz.config.param_config import ModelParameter
HistoryType = List[Tuple[str, str]] HistoryType = List[Tuple[str, str]]
def build_prompt( def build_prompt(
query: str, query: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None history: Optional[HistoryType] = None,
) -> str: ) -> str:
""" """
Build prompt in ChatML format for query and history. Build prompt in ChatML format for query and history.
@ -79,6 +80,7 @@ class GenerationRequest:
system_prompt: System prompt for the conversation. system_prompt: System prompt for the conversation.
stream: Whether to use streaming generation. stream: Whether to use streaming generation.
""" """
top_k: int top_k: int
top_p: float top_p: float
temperature: float temperature: float
@ -146,9 +148,12 @@ class StreamGenerator(GeneratorCore):
for _ in range(len(ids), self.config.max_len): for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_ids, request.temperature, request.top_k, request.top_p, input_ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches, kv_caches=kv_caches,
start_pos=cur_cache_pos start_pos=cur_cache_pos,
) )
input_ids = next_token_id input_ids = next_token_id
@ -172,7 +177,10 @@ class BatchGenerator(GeneratorCore):
if request.history is None: if request.history is None:
request.history = [[] for _ in range(batch_size)] request.history = [[] for _ in range(batch_size)]
prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)] prompts = [
build_prompt(query, history)
for query, history in zip(request.query, request.history)
]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts] ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id) ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
@ -189,13 +197,16 @@ class BatchGenerator(GeneratorCore):
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0: while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache() kv_caches = cache_manager.get_kvcache()
attn_mask =cache_manager.get_seq_mask() attn_mask = cache_manager.get_seq_mask()
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_tensor, request.temperature, request.top_k, request.top_p, input_tensor,
request.temperature,
request.top_k,
request.top_p,
attn_mask=attn_mask, attn_mask=attn_mask,
kv_caches=kv_caches, kv_caches=kv_caches,
start_pos=cur_cache_pos start_pos=cur_cache_pos,
) )
cur_cache_pos += cache_increase cur_cache_pos += cache_increase
@ -248,7 +259,9 @@ class GeneratorFactory:
""" """
@staticmethod @staticmethod
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: def create_generator(
parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Create a generator based on request characteristics. """Create a generator based on request characteristics.
Args: Args:
@ -282,7 +295,9 @@ class GeneratorFactory:
return EmbeddingEncoder(parameter) return EmbeddingEncoder(parameter)
@classmethod @classmethod
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: def create(
cls, parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Convenience method that delegates to create_generator. """Convenience method that delegates to create_generator.
Args: Args:
@ -293,4 +308,3 @@ class GeneratorFactory:
Generator instance Generator instance
""" """
return cls.create_generator(parameter, request) return cls.create_generator(parameter, request)
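
A hedged end-to-end sketch of the factory: `GenerationRequest` and `GeneratorFactory` are exported from `khaosz` as shown earlier, but any `GenerationRequest` field beyond `top_k`/`top_p`/`temperature`/`query`/`stream` is not visible in this diff, so the exact constructor call is an assumption.

```python
from khaosz import GenerationRequest, GeneratorFactory
from khaosz.config.param_config import ModelParameter

param = ModelParameter.load("params")        # a saved parameter directory
request = GenerationRequest(
    top_k=50,
    top_p=0.9,
    temperature=0.8,
    query="你好",
    stream=False,
)
generator = GeneratorFactory.create(param, request)   # picks a generator per request
```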


@ -7,11 +7,4 @@ from khaosz.model.module import (
) )
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
__all__ = [ __all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
"Transformer"
]


@ -25,11 +25,12 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
.reshape(bs, slen, n_heads * n_rep, head_dim) .reshape(bs, slen, n_heads * n_rep, head_dim)
) )
def get_rotary_emb( def get_rotary_emb(
dim: int, dim: int,
max_len: int, max_len: int,
base: float = 10000, base: float = 10000,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
""" """
Get the rotary embedding for the given dimension and maximum length. Get the rotary embedding for the given dimension and maximum length.
Args: Args:
@ -46,6 +47,7 @@ def get_rotary_emb(
return torch.cos(freqs).float(), torch.sin(freqs).float() return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
""" """
Apply rotary embedding to the input tensor using cos/sin form. Apply rotary embedding to the input tensor using cos/sin form.
@ -69,13 +71,13 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
x_imag_rot = x_real * sin + x_imag * cos x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2] x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim] x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype) return x_out.to(dtype)
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int=10000): def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.max_len = max_len self.max_len = max_len
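
For reference, the rotation `apply_rotary_emb` performs above treats consecutive feature pairs as complex numbers and rotates each pair by a position-dependent angle theta taken from the cached cos/sin tables:

```latex
\begin{aligned}
x'_{\mathrm{real}} &= x_{\mathrm{real}}\cos\theta - x_{\mathrm{imag}}\sin\theta \\
x'_{\mathrm{imag}} &= x_{\mathrm{real}}\sin\theta + x_{\mathrm{imag}}\cos\theta
\end{aligned}
```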
@ -89,7 +91,7 @@ class RotaryEmbedding(nn.Module):
self.register_buffer("sin_cached", sin_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len self.max_len_cached = max_len
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]: def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1) seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos: if self.max_len_cached < seq_len + start_pos:
@ -115,11 +117,11 @@ class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps): def __init__(self, dim, norm_eps):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(dim)) self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim, ) self.normalized_shape = (dim,)
self.norm_eps = norm_eps self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
return rms.to(x.dtype) return rms.to(x.dtype)
@ -136,7 +138,6 @@ class MLP(nn.Module):
return out return out
class GQA(nn.Module): class GQA(nn.Module):
def __init__( def __init__(
self, self,
@ -146,7 +147,7 @@ class GQA(nn.Module):
use_qk_norm: bool, use_qk_norm: bool,
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int layer_id: int,
): ):
super().__init__() super().__init__()
assert dim % n_heads == 0 assert dim % n_heads == 0
@ -184,7 +185,7 @@ class GQA(nn.Module):
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
@ -202,19 +203,24 @@ class GQA(nn.Module):
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
# copy to cache # copy to cache
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache # get cache
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim) # (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2) sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
)
if self.use_gated_attention: if self.use_gated_attention:
sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
@ -235,7 +241,7 @@ class MLA(nn.Module):
qk_rope_head_dim: int, qk_rope_head_dim: int,
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int layer_id: int,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -270,7 +276,7 @@ class MLA(nn.Module):
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
@ -285,12 +291,13 @@ class MLA(nn.Module):
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1) kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
k_nope, k_rope, v = torch.split( k_nope, k_rope, v = torch.split(
kv, kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
[self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim],
dim=-1
) )
q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_rope_head_dim:] q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_rope_head_dim :],
)
q_rope = apply_rotary_emb(q_rope, rotary_emb) q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb) k_rope = apply_rotary_emb(k_rope, rotary_emb)
@ -299,10 +306,10 @@ class MLA(nn.Module):
if kv_cache is not None: if kv_cache is not None:
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
@ -329,11 +336,18 @@ class DecoderBlock(nn.Module):
norm_eps: int, norm_eps: int,
use_qk_norm: bool, use_qk_norm: bool,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int layer_id: int,
): ):
super().__init__() super().__init__()
self.attention = GQA(dim, n_heads, n_kv_heads, self.attention = GQA(
use_qk_norm, norm_eps, use_gated_attention, layer_id) dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps) self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn) self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps) self.post_attention_norm = RMSNorm(dim, norm_eps)
@ -344,15 +358,11 @@ class DecoderBlock(nn.Module):
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Tensor:
# attention # attention
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
rotary_emb,
attention_mask,
kv_cache,
start_pos
) )
x = attn_output + x x = attn_output + x


@ -4,15 +4,21 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from typing import Any, Mapping, Optional, Tuple from typing import Any, Mapping, Optional, Tuple
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding from khaosz.model.module import (
Embedding,
DecoderBlock,
Linear,
RMSNorm,
RotaryEmbedding,
)
def process_attention_mask( def process_attention_mask(
seq_mask: Tensor, seq_mask: Tensor,
input_tensor: Tensor, input_tensor: Tensor,
start_pos: int = 0, start_pos: int = 0,
is_causal: bool = False, is_causal: bool = False,
) -> Tensor: ) -> Tensor:
""" """
Create attention mask for GQA Create attention mask for GQA
Args: Args:
@ -40,16 +46,20 @@ def process_attention_mask(
return seq_mask return seq_mask
batch_size = seq_mask.size(0) batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool) seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len) # (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len) expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
# (bsz, seq_len, start_pos + seq_len) # (bsz, seq_len, start_pos + seq_len)
if is_causal: if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1) attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos) # (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask return attention_mask
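
As a sanity check on the logic above, the padding-plus-causal mask can be built by hand for a tiny input; this is a self-contained sketch, not a call into `process_attention_mask`:

```python
import torch

dtype = torch.float32
seq_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)  # last token is padding
bsz, seq_len = seq_mask.shape

# Every query row sees the same key-validity pattern.
expanded = seq_mask.unsqueeze(1).expand(bsz, seq_len, seq_len)
expanded = torch.tril(expanded)  # causal: queries cannot attend to future keys

additive = torch.zeros(bsz, seq_len, seq_len, dtype=dtype)
additive = additive.masked_fill_(~expanded, -torch.finfo(dtype).max / 2)
print(additive.unsqueeze(1).shape)  # (1, 1, 4, 4), broadcastable over heads
```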
@ -59,14 +69,26 @@ class Transformer(nn.Module):
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len) self.rotary_embedding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
)
self.embed_tokens = Embedding(config.vocab_size, config.dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList(
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads, [
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id) DecoderBlock(
for layer_id in range(config.n_layers) config.dim,
]) config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps) self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size) self.lm_head = Linear(config.dim, config.vocab_size)
@ -77,8 +99,8 @@ class Transformer(nn.Module):
self._init_parameters() self._init_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False): def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = 'lm_head.weight' lm_head_key = "lm_head.weight"
embed_key = 'embed_tokens.weight' embed_key = "embed_tokens.weight"
if self.config.tie_weight == True: if self.config.tie_weight:
# same tensor # same tensor
@ -90,11 +112,13 @@ class Transformer(nn.Module):
return super().load_state_dict(state_dict, strict, assign) return super().load_state_dict(state_dict, strict, assign)
def state_dict(self, destination=None, prefix='', keep_vars=False): def state_dict(self, destination=None, prefix="", keep_vars=False):
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) state_dict = super().state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight == True: if self.config.tie_weight:
lm_head_key = prefix + 'lm_head.weight' lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict: if lm_head_key in state_dict:
del state_dict[lm_head_key] del state_dict[lm_head_key]
@ -108,18 +132,16 @@ class Transformer(nn.Module):
def forward( def forward(
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor]=None, input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None, persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0 start_pos: int = 0,
) -> Tensor: ) -> Dict[str, Tensor]:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embeding(x, start_pos) rotary_emb = self.rotary_embedding(x, start_pos)
attn_mask = process_attention_mask( attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
input_mask, x, start_pos, is_causal=True
)
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos) x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
@ -127,8 +149,4 @@ class Transformer(nn.Module):
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return { return {"logits": logits, "hidden_states": hidden_states}
"logits": logits,
"hidden_states": hidden_states
}
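
With `tie_weight` enabled, the overrides above drop `lm_head.weight` from saved state and point it at `embed_tokens.weight` on load, so one tensor backs both modules. A toy sketch of that contract (not the real `Transformer`):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # share one Parameter

toy = Toy()
sd = dict(toy.state_dict())
del sd["lm_head.weight"]  # what the state_dict() override above does

# What the load_state_dict() override above does: alias, then load.
sd["lm_head.weight"] = sd["embed_tokens.weight"]
toy.load_state_dict(sd)
assert toy.lm_head.weight.data_ptr() == toy.embed_tokens.weight.data_ptr()
```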

View File

@ -2,26 +2,20 @@ from khaosz.parallel.setup import (
get_world_size, get_world_size,
get_rank, get_rank,
get_current_device, get_current_device,
only_on_rank, only_on_rank,
setup_parallel, setup_parallel,
spawn_parallel_fn spawn_parallel_fn,
) )
from khaosz.parallel.module import ( from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
RowParallelLinear,
ColumnParallelLinear
)
__all__ = [ __all__ = [
"get_world_size", "get_world_size",
"get_rank", "get_rank",
"get_current_device", "get_current_device",
"only_on_rank", "only_on_rank",
"setup_parallel", "setup_parallel",
"spawn_parallel_fn", "spawn_parallel_fn",
"RowParallelLinear", "RowParallelLinear",
"ColumnParallelLinear" "ColumnParallelLinear",
] ]

View File

@ -22,7 +22,7 @@ class RowParallelLinear(ParallelModel):
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
reduce_results: bool = True reduce_results: bool = True,
): ):
super().__init__(process_group) super().__init__(process_group)
@ -32,7 +32,9 @@ class RowParallelLinear(ParallelModel):
self.reduce_results = reduce_results self.reduce_results = reduce_results
if in_features % self.world_size != 0: if in_features % self.world_size != 0:
raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}") raise ValueError(
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
)
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank)) self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
@ -49,8 +51,8 @@ class RowParallelLinear(ParallelModel):
return output return output
def load_state_dict(self, state_dict: Dict[str, Tensor]): def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get('weight') full_weight = state_dict.get("weight")
full_bias = state_dict.get('bias') full_bias = state_dict.get("bias")
start_idx = self.rank * self.in_features_per_rank start_idx = self.rank * self.in_features_per_rank
end_idx = start_idx + self.in_features_per_rank end_idx = start_idx + self.in_features_per_rank
@ -68,7 +70,7 @@ class ColumnParallelLinear(ParallelModel):
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
gather_results: bool = True gather_results: bool = True,
): ):
super().__init__(process_group) super().__init__(process_group)
@ -78,10 +80,16 @@ class ColumnParallelLinear(ParallelModel):
self.gather_results = gather_results self.gather_results = gather_results
if out_features % self.world_size != 0: if out_features % self.world_size != 0:
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}") raise ValueError(
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
)
self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features)) self.weight = nn.Parameter(
self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None torch.empty(self.out_features_per_rank, self.in_features)
)
self.bias = (
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
)
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight, self.bias) output = F.linear(input, self.weight, self.bias)
@ -94,8 +102,8 @@ class ColumnParallelLinear(ParallelModel):
return output return output
def load_state_dict(self, state_dict: Dict[str, Tensor]): def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get('weight') full_weight = state_dict.get("weight")
full_bias = state_dict.get('bias') full_bias = state_dict.get("bias")
start_idx = self.rank * self.out_features_per_rank start_idx = self.rank * self.out_features_per_rank
end_idx = start_idx + self.out_features_per_rank end_idx = start_idx + self.out_features_per_rank
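
The two classes shard the same `nn.Linear` weight along opposite axes; a small shape sketch under the assumption of `world_size = 2` (pure tensor slicing, no process group needed):

```python
import torch

in_features, out_features, world_size, rank = 8, 6, 2, 0
full_weight = torch.randn(out_features, in_features)  # nn.Linear layout

# RowParallelLinear: split the input dimension, all-reduce the outputs.
in_per_rank = in_features // world_size
row_shard = full_weight[:, rank * in_per_rank : (rank + 1) * in_per_rank]
assert row_shard.shape == (6, 4)

# ColumnParallelLinear: split the output dimension, all-gather the outputs.
out_per_rank = out_features // world_size
col_shard = full_weight[rank * out_per_rank : (rank + 1) * out_per_rank, :]
assert col_shard.shape == (3, 8)
```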

View File

@ -11,18 +11,21 @@ from typing import Callable, List, Optional
def get_current_device(): def get_current_device():
return os.environ["LOCAL_DEVICE"] return os.environ["LOCAL_DEVICE"]
def get_world_size() -> int: def get_world_size() -> int:
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
return dist.get_world_size() return dist.get_world_size()
else: else:
return 1 return 1
def get_rank() -> int: def get_rank() -> int:
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
return dist.get_rank() return dist.get_rank()
else: else:
return 0 return 0
@contextmanager @contextmanager
def setup_parallel( def setup_parallel(
rank: int, rank: int,
@ -31,7 +34,7 @@ def setup_parallel(
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
device_ids: Optional[List[int]] = None device_ids: Optional[List[int]] = None,
): ):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
@ -48,24 +51,21 @@ def setup_parallel(
rank = device_ids[rank % len(device_ids)] rank = device_ids[rank % len(device_ids)]
device_id = torch.device(device_type, device_ids[rank]) device_id = torch.device(device_type, device_ids[rank])
os.environ['MASTER_ADDR'] = master_addr os.environ["MASTER_ADDR"] = master_addr
os.environ['MASTER_PORT'] = master_port os.environ["MASTER_PORT"] = master_port
os.environ['LOCAL_RANK'] = str(rank) os.environ["LOCAL_RANK"] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
dist.init_process_group( dist.init_process_group(
rank=rank, rank=rank, world_size=world_size, backend=backend, device_id=device_id
world_size=world_size,
backend=backend,
device_id=device_id
) )
try: try:
if backend == "nccl" and torch.cuda.is_available(): if backend == "nccl" and torch.cuda.is_available():
torch.cuda.set_device(device_id) torch.cuda.set_device(device_id)
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available(): elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.set_device(device_id) torch.xpu.set_device(device_id)
yield dist.group.WORLD yield dist.group.WORLD
@ -73,6 +73,7 @@ def setup_parallel(
if dist.is_initialized(): if dist.is_initialized():
dist.destroy_process_group() dist.destroy_process_group()
def only_on_rank(rank, sync=False): def only_on_rank(rank, sync=False):
""" """
decorator to run a function only on a specific rank. Decorator to run a function only on a specific rank.
@ -94,6 +95,7 @@ def only_on_rank(rank, sync=False):
return decorator return decorator
def wrapper_spawn_func( def wrapper_spawn_func(
rank: int, rank: int,
world_size: int, world_size: int,
@ -103,7 +105,7 @@ def wrapper_spawn_func(
device_type: str, device_type: str,
device_ids: List[int], device_ids: List[int],
func: Callable, func: Callable,
kwargs: dict kwargs: dict,
): ):
try: try:
with setup_parallel( with setup_parallel(
@ -113,7 +115,7 @@ def wrapper_spawn_func(
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
device_type=device_type, device_type=device_type,
device_ids=device_ids device_ids=device_ids,
): ):
func(**kwargs) func(**kwargs)
@ -121,6 +123,7 @@ def wrapper_spawn_func(
print(f"Error in rank {rank}: {e}") print(f"Error in rank {rank}: {e}")
raise raise
def spawn_parallel_fn( def spawn_parallel_fn(
func: Callable, func: Callable,
world_size: int, world_size: int,
@ -129,10 +132,17 @@ def spawn_parallel_fn(
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
device_ids: Optional[List[int]] = None, device_ids: Optional[List[int]] = None,
**kwargs **kwargs,
): ):
# clear environment variables # clear environment variables
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']: for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ: if key in os.environ:
del os.environ[key] del os.environ[key]
@ -144,12 +154,17 @@ def spawn_parallel_fn(
func(**kwargs) func(**kwargs)
return return
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port, wrapper_spawn_func_args = (
device_type, device_ids, func, kwargs) world_size,
backend,
master_addr,
master_port,
device_type,
device_ids,
func,
kwargs,
)
mp.spawn( mp.spawn(
wrapper_spawn_func, wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
nprocs=world_size,
args=wrapper_spawn_func_args,
join=True
) )
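
A minimal way to exercise this entry point, mirroring the `gloo` usage in the test suite further down; extra keyword arguments are forwarded to `func` inside the initialized process group:

```python
import torch.distributed as dist
from khaosz.parallel.setup import spawn_parallel_fn

def hello(tag: str):
    # Runs once per spawned process with the process group already set up.
    print(f"{tag}: rank {dist.get_rank()} of {dist.get_world_size()}")

if __name__ == "__main__":
    spawn_parallel_fn(hello, world_size=2, backend="gloo", tag="demo")
```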

View File

@ -14,15 +14,12 @@ from khaosz.trainer.train_callback import (
__all__ = [ __all__ = [
# Main trainer # Main trainer
"Trainer", "Trainer",
# Strategy factory # Strategy factory
"StrategyFactory", "StrategyFactory",
"BaseStrategy", "BaseStrategy",
# Scheduler factory # Scheduler factory
"SchedulerFactory", "SchedulerFactory",
"BaseScheduler", "BaseScheduler",
# Callbacks # Callbacks
"TrainCallback", "TrainCallback",
"GradientClippingCallback", "GradientClippingCallback",

View File

@ -1,8 +1,9 @@
import torch.nn as nn import torch.nn as nn
from typing import Dict from typing import Dict
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
""" Compute gradient norm for each parameter in the model. """ """Compute gradient norm for each parameter in the model."""
norms = {} norms = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
norms[name] = 0.0 norms[name] = 0.0
@ -11,8 +12,9 @@ def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
norms[name] = norm norms[name] = norm
return norms return norms
def grad_std(model: nn.Module) -> Dict[str, float]: def grad_std(model: nn.Module) -> Dict[str, float]:
""" Compute standard deviation of gradients for each parameter. """ """Compute standard deviation of gradients for each parameter."""
stds = {} stds = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
stds[name] = 0.0 stds[name] = 0.0
@ -21,30 +23,33 @@ def grad_std(model: nn.Module) -> Dict[str, float]:
stds[name] = std stds[name] = std
return stds return stds
def grad_max(model: nn.Module) -> Dict[str, float]: def grad_max(model: nn.Module) -> Dict[str, float]:
""" Find the maximum absolute gradient value for each parameter. """ """Find the maximum absolute gradient value for each parameter."""
max_vals = {} max_vals = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
max_vals[name] = -float('inf') max_vals[name] = -float("inf")
if param.grad: if param.grad is not None:
max_val = param.grad.data.max().item() max_val = param.grad.data.max().item()
max_vals[name] = max_val max_vals[name] = max_val
return max_vals return max_vals
def grad_min(model: nn.Module) -> Dict[str, float]: def grad_min(model: nn.Module) -> Dict[str, float]:
""" Find the minimum absolute gradient value for each parameter. """ """Find the minimum absolute gradient value for each parameter."""
min_vals = {} min_vals = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
min_vals[name] = float('inf') min_vals[name] = float("inf")
if param.grad: if param.grad is not None:
min_val = param.grad.data.min().item() min_val = param.grad.data.min().item()
min_vals[name] = min_val min_vals[name] = min_val
return min_vals return min_vals
def grad_mean(model: nn.Module) -> Dict[str, float]: def grad_mean(model: nn.Module) -> Dict[str, float]:
""" Compute mean of gradients for each parameter. """ """Compute mean of gradients for each parameter."""
means = {} means = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
means[name] = 0.0 means[name] = 0.0
@ -54,8 +59,9 @@ def grad_mean(model: nn.Module) -> Dict[str, float]:
return means return means
def grad_nan_num(model: nn.Module) -> Dict[str, int]: def grad_nan_num(model: nn.Module) -> Dict[str, int]:
""" Count the number of NaNs in gradients for each parameter. """ """Count the number of NaNs in gradients for each parameter."""
nan_nums = {} nan_nums = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
nan_nums[name] = 0 nan_nums[name] = 0
@ -64,26 +70,34 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]:
nan_nums[name] = nan_num nan_nums[name] = nan_num
return nan_nums return nan_nums
def ctx_get_loss(ctx): def ctx_get_loss(ctx):
return ctx.loss return ctx.loss
def ctx_get_lr(ctx): def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]['lr'] return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_grad_norm(ctx): def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model) return grad_norm(ctx.model)
def ctx_get_grad_std(ctx): def ctx_get_grad_std(ctx):
return grad_std(ctx.model) return grad_std(ctx.model)
def ctx_get_grad_max(ctx): def ctx_get_grad_max(ctx):
return grad_max(ctx.model) return grad_max(ctx.model)
def ctx_get_grad_min(ctx): def ctx_get_grad_min(ctx):
return grad_min(ctx.model) return grad_min(ctx.model)
def ctx_get_grad_mean(ctx): def ctx_get_grad_mean(ctx):
return grad_mean(ctx.model) return grad_mean(ctx.model)
def ctx_get_grad_nan_num(ctx): def ctx_get_grad_nan_num(ctx):
return grad_nan_num(ctx.model) return grad_nan_num(ctx.model)
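
These helpers read whatever `.grad` holds at call time, so they only make sense after a backward pass; a quick usage sketch, assuming the elided None-checks guard parameters without gradients:

```python
import torch
import torch.nn as nn
from khaosz.trainer.metric_util import grad_nan_num, grad_norm

model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()

print(grad_norm(model))     # per-parameter L2 norms: {'weight': ..., 'bias': ...}
print(grad_nan_num(model))  # per-parameter NaN counts, all zero on a healthy step
```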

View File

@ -55,11 +55,15 @@ class SchedulerFactory:
Returns: Returns:
Decorator function that registers the scheduler class Decorator function that registers the scheduler class
""" """
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]: def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
if not issubclass(scheduler_cls, BaseScheduler): if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") raise TypeError(
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
)
cls.SCHEDULER_MAP[name] = scheduler_cls cls.SCHEDULER_MAP[name] = scheduler_cls
return scheduler_cls return scheduler_cls
return decorator return decorator
@classmethod @classmethod
@ -121,7 +125,7 @@ class CosineScheduler(BaseScheduler):
warmup_steps: int, warmup_steps: int,
lr_decay_steps: int, lr_decay_steps: int,
min_rate: float = 0.05, min_rate: float = 0.05,
last_epoch: int = -1 last_epoch: int = -1,
): ):
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.lr_decay_steps = lr_decay_steps self.lr_decay_steps = lr_decay_steps
@ -129,7 +133,6 @@ class CosineScheduler(BaseScheduler):
self.total_steps = warmup_steps + lr_decay_steps self.total_steps = warmup_steps + lr_decay_steps
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]: def get_lr(self) -> List[float]:
# warmup # warmup
if self.last_epoch < self.warmup_steps: if self.last_epoch < self.warmup_steps:
@ -145,19 +148,21 @@ class CosineScheduler(BaseScheduler):
def state_dict(self): def state_dict(self):
state = super().state_dict() state = super().state_dict()
state.update({ state.update(
'warmup_steps': self.warmup_steps, {
'lr_decay_steps': self.lr_decay_steps, "warmup_steps": self.warmup_steps,
'min_rate': self.min_rate, "lr_decay_steps": self.lr_decay_steps,
'total_steps': self.total_steps, "min_rate": self.min_rate,
}) "total_steps": self.total_steps,
}
)
return state return state
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self.warmup_steps = state_dict.pop('warmup_steps') self.warmup_steps = state_dict.pop("warmup_steps")
self.lr_decay_steps = state_dict.pop('lr_decay_steps') self.lr_decay_steps = state_dict.pop("lr_decay_steps")
self.min_rate = state_dict.pop('min_rate') self.min_rate = state_dict.pop("min_rate")
self.total_steps = state_dict.pop('total_steps') self.total_steps = state_dict.pop("total_steps")
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
@ -181,7 +186,6 @@ class SGDRScheduler(BaseScheduler):
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
def get_lr(self): def get_lr(self):
# warmup # warmup
if self.last_epoch < self.warmup_steps: if self.last_epoch < self.warmup_steps:
@ -204,7 +208,9 @@ class SGDRScheduler(BaseScheduler):
steps_in_cycle = steps_since_warmup - total_cycles_length steps_in_cycle = steps_since_warmup - total_cycles_length
# 2. Cosine annealing within the current cycle # 2. Cosine annealing within the current cycle
cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)) cosine_factor = 0.5 * (
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
)
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
return [base_lr * learning_rate_factor for base_lr in self.base_lrs] return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
@ -212,18 +218,20 @@ class SGDRScheduler(BaseScheduler):
def state_dict(self): def state_dict(self):
"""Returns the state of the scheduler as a dict.""" """Returns the state of the scheduler as a dict."""
state = super().state_dict() state = super().state_dict()
state.update({ state.update(
'warmup_steps': self.warmup_steps, {
'cycle_length': self.cycle_length, "warmup_steps": self.warmup_steps,
'min_rate': self.min_rate, "cycle_length": self.cycle_length,
't_mult': self.t_mult "min_rate": self.min_rate,
}) "t_mult": self.t_mult,
}
)
return state return state
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
"""Loads the scheduler's state.""" """Loads the scheduler's state."""
self.warmup_steps = state_dict.pop('warmup_steps') self.warmup_steps = state_dict.pop("warmup_steps")
self.cycle_length = state_dict.pop('cycle_length') self.cycle_length = state_dict.pop("cycle_length")
self.min_rate = state_dict.pop('min_rate') self.min_rate = state_dict.pop("min_rate")
self.t_mult = state_dict.pop('t_mult') self.t_mult = state_dict.pop("t_mult")
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
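
For reference, the warmup-plus-cosine factor that `CosineScheduler` applies to each base LR can be reproduced in a few lines; this is a sketch of the schedule shape, and the exact boundary handling in the class may differ:

```python
import math

def cosine_factor(step, warmup_steps=100, lr_decay_steps=900, min_rate=0.05):
    if step < warmup_steps:
        return step / warmup_steps  # linear warmup from 0 to 1
    t = min(step - warmup_steps, lr_decay_steps) / lr_decay_steps
    return min_rate + (1 - min_rate) * 0.5 * (1 + math.cos(math.pi * t))

print(cosine_factor(50), cosine_factor(100), cosine_factor(1000))
# 0.5, 1.0, 0.05: ramp up, peak, then decay to the min_rate floor
```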

View File

@ -55,7 +55,9 @@ def get_logprobs(
""" """
allowed_reductions = ["mean", "sum", "none"] allowed_reductions = ["mean", "sum", "none"]
if reduction not in allowed_reductions: if reduction not in allowed_reductions:
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'") raise ValueError(
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
)
shifted_input_ids = input_ids[:, 1:] shifted_input_ids = input_ids[:, 1:]
shifted_mask = mask[:, 1:] shifted_mask = mask[:, 1:]
@ -64,13 +66,13 @@ def get_logprobs(
log_probs = torch.log_softmax(logits.float(), dim=-1) log_probs = torch.log_softmax(logits.float(), dim=-1)
token_logprobs = torch.gather( token_logprobs = torch.gather(
log_probs, log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
dim=-1,
index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1) ).squeeze(-1)
if reduction == "mean": if reduction == "mean":
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0) return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
dim=-1
).clamp(min=1.0)
elif reduction == "sum": elif reduction == "sum":
return (token_logprobs * shifted_mask).sum(dim=-1) return (token_logprobs * shifted_mask).sum(dim=-1)
else: else:
@ -80,7 +82,9 @@ def get_logprobs(
class BaseStrategy(ABC): class BaseStrategy(ABC):
"""Abstract base class for training strategies.""" """Abstract base class for training strategies."""
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): def __init__(
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
):
self.model = model self.model = model
self.device = device self.device = device
@ -128,11 +132,15 @@ class StrategyFactory:
Returns: Returns:
Decorator function that registers the strategy class Decorator function that registers the strategy class
""" """
def decorator(strategy_cls: type) -> type: def decorator(strategy_cls: type) -> type:
if not issubclass(strategy_cls, BaseStrategy): if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") raise TypeError(
f"{strategy_cls.__name__} must inherit from BaseStrategy"
)
cls.STRATEGY_MAP[name] = strategy_cls cls.STRATEGY_MAP[name] = strategy_cls
return strategy_cls return strategy_cls
return decorator return decorator
@classmethod @classmethod
@ -195,7 +203,7 @@ class SEQStrategy(BaseStrategy):
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),
target=target_ids.flatten(), target=target_ids.flatten(),
label_smoothing=self.label_smoothing label_smoothing=self.label_smoothing,
) )
return loss return loss
@ -214,7 +222,11 @@ class SFTStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"] input_ids, target_ids, loss_mask = (
batch["input_ids"],
batch["target_ids"],
batch["loss_mask"],
)
ignore_index = -100 ignore_index = -100
logits = self.model(input_ids=input_ids)["logits"] logits = self.model(input_ids=input_ids)["logits"]
@ -224,7 +236,7 @@ class SFTStrategy(BaseStrategy):
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),
target=target_ids.flatten(), target=target_ids.flatten(),
ignore_index=ignore_index, ignore_index=ignore_index,
label_smoothing=self.label_smoothing label_smoothing=self.label_smoothing,
) )
return loss return loss
@ -239,12 +251,12 @@ class DPOStrategy(BaseStrategy):
""" """
def __init__( def __init__(
self, self,
model: nn.Module, model: nn.Module,
device: str, device: str,
beta: float = 0.1, beta: float = 0.1,
reduction: str = "mean", reduction: str = "mean",
): ):
super().__init__(model, device) super().__init__(model, device)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(model)
self.beta = beta self.beta = beta
@ -261,12 +273,14 @@ class DPOStrategy(BaseStrategy):
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction) log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
with torch.no_grad(): with torch.no_grad():
log_ref = get_logprobs(self.ref_model, contact_ids, contact_mask, self.reduction) log_ref = get_logprobs(
self.ref_model, contact_ids, contact_mask, self.reduction
)
log_pi_chosen = log_pi[:chosen_ids.shape[0]] log_pi_chosen = log_pi[: chosen_ids.shape[0]]
log_pi_rejected = log_pi[chosen_ids.shape[0]:] log_pi_rejected = log_pi[chosen_ids.shape[0] :]
log_ref_chosen = log_ref[:chosen_ids.shape[0]] log_ref_chosen = log_ref[: chosen_ids.shape[0]]
log_ref_rejected = log_ref[chosen_ids.shape[0]:] log_ref_rejected = log_ref[chosen_ids.shape[0] :]
pi_log_ratio = log_pi_chosen - log_pi_rejected pi_log_ratio = log_pi_chosen - log_pi_rejected
ref_log_ratio = log_ref_chosen - log_ref_rejected ref_log_ratio = log_ref_chosen - log_ref_rejected
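
The loss line itself falls outside this hunk; the standard DPO completion of these two ratios, shown here as an assumption rather than the class's verbatim code, is:

```python
import torch
import torch.nn.functional as F

beta = 0.1
pi_log_ratio = torch.randn(4)   # log p(chosen) - log p(rejected) under the policy
ref_log_ratio = torch.randn(4)  # the same ratio under the frozen reference
# DPO: maximize the margin between policy and reference log-ratios.
loss = -F.logsigmoid(beta * (pi_log_ratio - ref_log_ratio)).mean()
```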
@ -316,11 +330,15 @@ class GRPOStrategy(BaseStrategy):
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1) full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1) full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction) log_probs_policy = get_logprobs(
self.model, full_sequences, full_masks, self.reduction
)
log_probs_policy = log_probs_policy.view(batch_size, group_size) log_probs_policy = log_probs_policy.view(batch_size, group_size)
with torch.no_grad(): with torch.no_grad():
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction) log_probs_ref = get_logprobs(
self.ref_model, full_sequences, full_masks, self.reduction
)
log_probs_ref = log_probs_ref.view(batch_size, group_size) log_probs_ref = log_probs_ref.view(batch_size, group_size)
# Compute advantages from rewards with normalization # Compute advantages from rewards with normalization
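
The trailing comment marks where group-relative advantages are computed; since those lines are elided from the diff, here is one common normalization, included purely as an illustrative assumption:

```python
import torch

rewards = torch.randn(2, 4)  # (batch_size, group_size) scalar rewards
mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + 1e-8)  # zero-mean within each group
```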

View File

@ -18,7 +18,7 @@ from khaosz.trainer.metric_util import (
ctx_get_grad_norm, ctx_get_grad_norm,
ctx_get_grad_mean, ctx_get_grad_mean,
ctx_get_grad_std, ctx_get_grad_std,
ctx_get_grad_nan_num ctx_get_grad_nan_num,
) )
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
from khaosz.trainer.train_context import TrainContext from khaosz.trainer.train_context import TrainContext
@ -30,37 +30,38 @@ class TrainCallback(Protocol):
""" """
def on_train_begin(self, context: TrainContext): def on_train_begin(self, context: TrainContext):
""" Called at the beginning of training. """ """Called at the beginning of training."""
def on_train_end(self, context: TrainContext): def on_train_end(self, context: TrainContext):
""" Called at the end of training. """ """Called at the end of training."""
def on_epoch_begin(self, context: TrainContext): def on_epoch_begin(self, context: TrainContext):
""" Called at the beginning of each epoch. """ """Called at the beginning of each epoch."""
def on_epoch_end(self, context: TrainContext): def on_epoch_end(self, context: TrainContext):
""" Called at the end of each epoch. """ """Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext): def on_step_begin(self, context: TrainContext):
""" Called at the beginning of each step. """ """Called at the beginning of each step."""
def on_step_end(self, context: TrainContext): def on_step_end(self, context: TrainContext):
""" Called at the end of each step.""" """Called at the end of each step."""
def on_batch_begin(self, context: TrainContext): def on_batch_begin(self, context: TrainContext):
""" Called at the beginning of each batch. """ """Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
""" Called at the end of each batch. """ """Called at the end of each batch."""
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
""" Called when an error occurs during training. """ """Called when an error occurs during training."""
class GradientClippingCallback(TrainCallback): class GradientClippingCallback(TrainCallback):
""" """
Gradient clipping callback for trainer. Gradient clipping callback for trainer.
""" """
def __init__(self, max_grad_norm: float): def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
@ -73,6 +74,7 @@ class SchedulerCallback(TrainCallback):
""" """
Scheduler callback for trainer. Scheduler callback for trainer.
""" """
def __init__(self): def __init__(self):
pass pass
@ -90,12 +92,13 @@ class CheckpointCallback(TrainCallback):
""" """
Checkpoint callback for trainer. Checkpoint callback for trainer.
""" """
def __init__( def __init__(
self, self,
save_dir: str, save_dir: str,
interval: int, interval: int,
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
@ -105,13 +108,17 @@ class CheckpointCallback(TrainCallback):
@only_on_rank(0) @only_on_rank(0)
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") save_path = os.path.join(
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict() self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
epoch=context.epoch,
iteration=context.iteration
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
@ -133,6 +140,7 @@ class ProgressBarCallback(TrainCallback):
""" """
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
def __init__(self, num_epoch: int): def __init__(self, num_epoch: int):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
@ -141,16 +149,18 @@ class ProgressBarCallback(TrainCallback):
def on_epoch_begin(self, context: TrainContext): def on_epoch_begin(self, context: TrainContext):
self.progress_bar = tqdm( self.progress_bar = tqdm(
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch+1}/{self.num_epoch}", desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True dynamic_ncols=True,
) )
@only_on_rank(0) @only_on_rank(0)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
self.progress_bar.set_postfix({ self.progress_bar.set_postfix(
"loss": f"{context.loss:.4f}", {
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}" "loss": f"{context.loss:.4f}",
}) "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
}
)
self.progress_bar.update(1) self.progress_bar.update(1)
@only_on_rank(0) @only_on_rank(0)
@ -163,15 +173,15 @@ class ProgressBarCallback(TrainCallback):
class MetricLoggerCallback(TrainCallback): class MetricLoggerCallback(TrainCallback):
def __init__( def __init__(
self, self,
log_dir:str, log_dir: str,
save_interval:int, save_interval: int,
log_interval:int=10, log_interval: int = 10,
metrics:List[str]=None metrics: List[str] = None,
): ):
self.last_log_iter = 0 self.last_log_iter = 0
self.save_interval = save_interval self.save_interval = save_interval
self.log_interval = log_interval self.log_interval = log_interval
self.metrics = metrics or ['loss', 'lr'] self.metrics = metrics or ["loss", "lr"]
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs" self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
self.log_dir.mkdir(parents=True, exist_ok=True) self.log_dir.mkdir(parents=True, exist_ok=True)
@ -179,22 +189,22 @@ class MetricLoggerCallback(TrainCallback):
self.log_cache = [] self.log_cache = []
self._metric_funcs = { self._metric_funcs = {
'loss': ctx_get_loss, "loss": ctx_get_loss,
'lr': ctx_get_lr, "lr": ctx_get_lr,
'grad_norm': ctx_get_grad_norm, "grad_norm": ctx_get_grad_norm,
'grad_std': ctx_get_grad_std, "grad_std": ctx_get_grad_std,
'grad_max': ctx_get_grad_max, "grad_max": ctx_get_grad_max,
'grad_min': ctx_get_grad_min, "grad_min": ctx_get_grad_min,
'grad_mean': ctx_get_grad_mean, "grad_mean": ctx_get_grad_mean,
'grad_nan_num': ctx_get_grad_nan_num "grad_nan_num": ctx_get_grad_nan_num,
} }
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { return {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.iteration, "iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics} **{m: self._metric_funcs[m](context) for m in self.metrics},
} }
@only_on_rank(0) @only_on_rank(0)
@ -205,9 +215,9 @@ class MetricLoggerCallback(TrainCallback):
def _save_log(self, epoch, iter): def _save_log(self, epoch, iter):
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl" log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
with open(log_file, 'w') as f: with open(log_file, "w") as f:
for log in self.log_cache: for log in self.log_cache:
f.write(json.dumps(log) + '\n') f.write(json.dumps(log) + "\n")
def on_batch_end(self, context): def on_batch_end(self, context):
if context.iteration % self.log_interval == 0: if context.iteration % self.log_interval == 0:
@ -224,4 +234,3 @@ class MetricLoggerCallback(TrainCallback):
def on_error(self, context): def on_error(self, context):
self._save_log(context.epoch, context.iteration) self._save_log(context.epoch, context.iteration)
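
Since `TrainCallback` only defines no-op hooks, a custom callback overrides just what it needs; a minimal sketch with an invented class name:

```python
from khaosz.trainer.train_callback import TrainCallback
from khaosz.trainer.train_context import TrainContext

class LossSpikeCallback(TrainCallback):
    """Warn when the loss jumps sharply between batches (illustrative only)."""

    def __init__(self, factor: float = 3.0):
        self.factor = factor
        self.last_loss = None

    def on_batch_end(self, context: TrainContext):
        if self.last_loss is not None and context.loss > self.factor * self.last_loss:
            print(f"loss spike at iter {context.iteration}: {context.loss:.4f}")
        self.last_loss = context.loss
```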

View File

@ -72,7 +72,7 @@ class TrainContextBuilder:
data_source=config.dataset, data_source=config.dataset,
start_epoch=self._context.epoch, start_epoch=self._context.epoch,
start_iter=sampler_offset, start_iter=sampler_offset,
seed=config.random_seed seed=config.random_seed,
) )
dataloader = DataLoader( dataloader = DataLoader(
@ -81,7 +81,7 @@ class TrainContextBuilder:
sampler=resumeable_sampler, sampler=resumeable_sampler,
num_workers=config.num_workers, num_workers=config.num_workers,
pin_memory=config.pin_memory, pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor prefetch_factor=config.prefetch_factor,
) )
self._context.dataloader = dataloader self._context.dataloader = dataloader
return self return self
@ -91,7 +91,7 @@ class TrainContextBuilder:
model=self._context.model, model=self._context.model,
train_type=self.config.strategy, train_type=self.config.strategy,
device=get_current_device(), device=get_current_device(),
**self.config.extra_kwargs **self.config.extra_kwargs,
) )
return self return self

View File

@ -7,7 +7,7 @@ from khaosz.trainer.train_callback import (
CheckpointCallback, CheckpointCallback,
MetricLoggerCallback, MetricLoggerCallback,
GradientClippingCallback, GradientClippingCallback,
SchedulerCallback SchedulerCallback,
) )
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
@ -18,30 +18,32 @@ logger = logging.getLogger(__name__)
class Trainer: class Trainer:
def __init__( def __init__(
self, self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
train_config: TrainConfig,
callbacks: Optional[List[TrainCallback]] = None
): ):
self.train_config = train_config self.train_config = train_config
default_callbacks = self._get_default_callbacks() default_callbacks = self._get_default_callbacks()
self.callbacks = default_callbacks + callbacks if callbacks else default_callbacks self.callbacks = (
default_callbacks + callbacks if callbacks else default_callbacks
)
def _get_default_callbacks(self) -> List[TrainCallback]: def _get_default_callbacks(self) -> List[TrainCallback]:
train_config = self.train_config train_config = self.train_config
return [ return [
ProgressBarCallback(train_config.n_epoch), ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
GradientClippingCallback(train_config.max_grad_norm), GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(), SchedulerCallback(),
] ]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self.train_config) return (
.with_checkpoint(checkpoint) TrainContextBuilder(self.train_config)
.with_dataloader() .with_checkpoint(checkpoint)
.with_strategy() .with_dataloader()
.build()) .with_strategy()
.build()
)
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks: for callback in self.callbacks:
@ -59,23 +61,23 @@ class Trainer:
master_port=config.master_port, master_port=config.master_port,
device_type=config.device_type, device_type=config.device_type,
device_ids=config.device_ids, device_ids=config.device_ids,
checkpoint=checkpoint checkpoint=checkpoint,
) )
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint) context = self._build_context(checkpoint)
self._call_callbacks('on_train_begin', context) self._call_callbacks("on_train_begin", context)
try: try:
context.model.train() context.model.train()
# 1.epoch # 1. epoch
for epoch in range(context.epoch, self.train_config.n_epoch): for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch context.epoch = epoch
self._call_callbacks('on_epoch_begin', context) self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader: for batch in context.dataloader:
# 3. batch # 3. batch
self._call_callbacks('on_batch_begin', context) self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch) loss = context.strategy(batch)
context.loss = loss.item() context.loss = loss.item()
context.iteration += 1 context.iteration += 1
@ -84,20 +86,20 @@ class Trainer:
stand_loss = loss / self.train_config.accumulation_steps stand_loss = loss / self.train_config.accumulation_steps
stand_loss.backward() stand_loss.backward()
self._call_callbacks('on_batch_end', context) self._call_callbacks("on_batch_end", context)
if context.iteration % self.train_config.accumulation_steps == 0: if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step # 2. step
self._call_callbacks('on_step_begin', context) self._call_callbacks("on_step_begin", context)
context.optimizer.step() context.optimizer.step()
context.optimizer.zero_grad() context.optimizer.zero_grad()
self._call_callbacks('on_step_end', context) self._call_callbacks("on_step_end", context)
self._call_callbacks('on_epoch_end', context) self._call_callbacks("on_epoch_end", context)
except Exception as e: except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True) logger.error(f"Training failed: {str(e)}", exc_info=True)
self._call_callbacks('on_error', context) self._call_callbacks("on_error", context)
raise raise
finally: finally:
self._call_callbacks('on_train_end', context) self._call_callbacks("on_train_end", context)
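
Tying it together: user callbacks run after the default stack (progress bar, checkpointing, metric logging, gradient clipping, scheduler), so they observe state the defaults have already updated. A construction sketch, assuming a `TrainConfig` assembled elsewhere; the public training entry point itself sits outside this diff:

```python
from khaosz.trainer import Trainer, TrainCallback

def make_trainer(train_config, extra: list[TrainCallback] | None = None):
    # Defaults are created inside Trainer and run before any callbacks
    # passed here; extra may be None for the default stack alone.
    return Trainer(train_config, callbacks=extra)
```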

View File

@ -26,7 +26,7 @@ classifiers = [
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" } urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
[project.optional-dependencies] [project.optional-dependencies]
dev = ["pytest==9.0.2"] dev = ["pytest==9.0.2", "ruff"]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["."] where = ["."]
@ -36,3 +36,12 @@ extra-index-url = "https://download.pytorch.org/whl/cu126"
[tool.setuptools.dynamic] [tool.setuptools.dynamic]
version = { attr = "khaosz.__version__" } version = { attr = "khaosz.__version__" }
[tool.ruff]
target-version = "py312"
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
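
With this table in place, running `ruff format .` from the repository root applies the settings above, and `ruff format --check .` reports, without rewriting, any file that drifts from them.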

View File

@ -24,7 +24,7 @@ class RandomDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
return { return {
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)), "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,)) "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
} }
@ -65,7 +65,7 @@ class EarlyStoppingDataset(Dataset):
return { return {
"input_ids": torch.randint(0, 1000, (64,)), "input_ids": torch.randint(0, 1000, (64,)),
"target_ids": torch.randint(0, 1000, (64,)) "target_ids": torch.randint(0, 1000, (64,)),
} }
@ -91,10 +91,10 @@ def base_test_env(request: pytest.FixtureRequest):
"dim_ffn": dim_ffn, "dim_ffn": dim_ffn,
"max_len": 1024, "max_len": 1024,
"n_layers": 4, "n_layers": 4,
"norm_eps": 1e-5 "norm_eps": 1e-5,
} }
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config, f) json.dump(config, f)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_config = ModelConfig().load(config_path) transformer_config = ModelConfig().load(config_path)
@ -112,16 +112,19 @@ def base_test_env(request: pytest.FixtureRequest):
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
@pytest.fixture @pytest.fixture
def random_dataset(): def random_dataset():
dataset = RandomDataset() dataset = RandomDataset()
yield dataset yield dataset
@pytest.fixture @pytest.fixture
def multi_turn_dataset(): def multi_turn_dataset():
dataset = MultiTurnDataset() dataset = MultiTurnDataset()
yield dataset yield dataset
@pytest.fixture @pytest.fixture
def early_stopping_dataset(): def early_stopping_dataset():
dataset = EarlyStoppingDataset() dataset = EarlyStoppingDataset()

View File

@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
from khaosz.parallel.setup import get_rank, spawn_parallel_fn from khaosz.parallel.setup import get_rank, spawn_parallel_fn
def test_single_process(): def test_single_process():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)
@ -14,7 +15,6 @@ def test_single_process():
for epoch in range(3): for epoch in range(3):
for iteration in range(10): for iteration in range(10):
x = torch.randn(32, 10) x = torch.randn(32, 10)
y = torch.randn(32, 5) y = torch.randn(32, 5)
loss = model(x).mean() loss = model(x).mean()
@ -24,11 +24,7 @@ def test_single_process():
scheduler.step() scheduler.step()
checkpoint = Checkpoint( checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30)
state_dict=model.state_dict(),
epoch=3,
iteration=30
)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir) checkpoint.save(tmpdir)
@ -37,6 +33,8 @@ def test_single_process():
assert loaded_checkpoint.epoch == 3 assert loaded_checkpoint.epoch == 3
assert loaded_checkpoint.iteration == 30 assert loaded_checkpoint.iteration == 30
def simple_training(): def simple_training():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)
@ -66,19 +64,14 @@ def simple_training():
else: else:
shared_dir = None shared_dir = None
if dist.is_initialized(): if dist.is_initialized():
dir_list = [shared_dir] dir_list = [shared_dir]
dist.broadcast_object_list(dir_list, src=0) dist.broadcast_object_list(dir_list, src=0)
shared_dir = dir_list[0] shared_dir = dir_list[0]
loaded = Checkpoint.load(shared_dir) loaded = Checkpoint.load(shared_dir)
assert loaded.epoch == 2 assert loaded.epoch == 2
def test_multi_process(): def test_multi_process():
spawn_parallel_fn( spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
simple_training,
world_size=2,
backend="gloo"
)

View File

@ -5,7 +5,6 @@ from khaosz.data.serialization import save_h5
from khaosz.data.dataset import * from khaosz.data.dataset import *
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):
"""Test dataset loader with multiple random paths""" """Test dataset loader with multiple random paths"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
@ -16,7 +15,10 @@ def test_dataset_loader_random_paths(base_test_env):
for i in range(num_files): for i in range(num_files):
seq_length = np.random.randint(200, 400) seq_length = np.random.randint(200, 400)
dummy_data = { dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)], "sequence": [
torch.randint(0, 1000, (seq_length,), dtype=torch.int64)
for _ in range(10)
],
} }
save_h5(test_dir, f"data_{i}", dummy_data) save_h5(test_dir, f"data_{i}", dummy_data)
@ -49,7 +51,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)], "chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)] "rejected_mask": [torch.ones(seq_length, dtype=torch.bool)],
} }
save_h5(test_dir, "dpo_data", dummy_data) save_h5(test_dir, "dpo_data", dummy_data)
@ -62,7 +64,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
) )
assert dpo_dataset is not None assert dpo_dataset is not None
assert hasattr(dpo_dataset, 'fetcher') assert hasattr(dpo_dataset, "fetcher")
assert len(dpo_dataset) > 0 assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors # Test that we can get DPO items without errors
@ -85,7 +87,7 @@ def test_sft_dataset_with_random_data(base_test_env):
dummy_data = { dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)] "loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
} }
save_h5(test_dir, "sft_data", dummy_data) save_h5(test_dir, "sft_data", dummy_data)
@ -98,7 +100,7 @@ def test_sft_dataset_with_random_data(base_test_env):
) )
assert sft_dataset is not None assert sft_dataset is not None
assert hasattr(sft_dataset, 'fetcher') assert hasattr(sft_dataset, "fetcher")
assert len(sft_dataset) > 0 assert len(sft_dataset) > 0
# Test that we can get SFT items without errors # Test that we can get SFT items without errors
@ -121,15 +123,12 @@ def test_dataset_with_custom_stride(base_test_env):
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
} }
save_h5(test_dir,"stride_test_data", dummy_data) save_h5(test_dir, "stride_test_data", dummy_data)
# Test with custom stride # Test with custom stride
custom_stride = 32 custom_stride = 32
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type="seq", train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
load_path=test_dir,
window_size=64,
stride=custom_stride
) )
assert dataset is not None assert dataset is not None

View File

@ -1,6 +1,7 @@
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data import * from khaosz.data import *
def test_random_sampler_consistency(random_dataset): def test_random_sampler_consistency(random_dataset):
"""Test RandomSampler produces consistent results with same seed""" """Test RandomSampler produces consistent results with same seed"""
dataset = random_dataset dataset = random_dataset
@ -14,6 +15,7 @@ def test_random_sampler_consistency(random_dataset):
assert indices1 == indices2 assert indices1 == indices2
def test_random_sampler_different_seeds(random_dataset): def test_random_sampler_different_seeds(random_dataset):
"""Test RandomSampler produces different results with different seeds""" """Test RandomSampler produces different results with different seeds"""
dataset = random_dataset dataset = random_dataset

View File

@ -12,6 +12,7 @@ from khaosz.data import *
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
from tokenizers import pre_tokenizers from tokenizers import pre_tokenizers
@pytest.fixture @pytest.fixture
def test_env(request: pytest.FixtureRequest): def test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__ func_name = request.function.__name__
@ -28,9 +29,9 @@ def test_env(request: pytest.FixtureRequest):
"dim_ffn": 256, "dim_ffn": 256,
"max_len": 64, "max_len": 64,
"n_layers": 2, "n_layers": 2,
"norm_eps": 1e-5 "norm_eps": 1e-5,
} }
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config, f) json.dump(config, f)
tokenizer = BpeTokenizer() tokenizer = BpeTokenizer()
@ -51,30 +52,40 @@ def test_env(request: pytest.FixtureRequest):
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
def test_model_parameter(test_env): def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save") save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"]) model_param = ModelParameter(
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
ModelParameter.save(model_param, save_dir) ModelParameter.save(model_param, save_dir)
assert os.path.exists(os.path.join(save_dir, "model.safetensors")) assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
assert os.path.exists(os.path.join(save_dir, "tokenizer.json")) assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
assert os.path.exists(os.path.join(save_dir, "config.json")) assert os.path.exists(os.path.join(save_dir, "config.json"))
# transformer # transformer
def test_transformer(test_env): def test_transformer(test_env):
model = test_env["model"] model = test_env["model"]
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, input_ids = torch.randint(
(4, test_env["transformer_config"].max_len)) 0,
test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].max_len),
)
output_logits = model(input_ids)["logits"] output_logits = model(input_ids)["logits"]
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size) target_shape = (
4,
test_env["transformer_config"].max_len,
test_env["transformer_config"].vocab_size,
)
assert output_logits.shape == target_shape assert output_logits.shape == target_shape
# generator # generator
def test_embedding_encoder_core(test_env): def test_embedding_encoder_core(test_env):
parameter = ModelParameter( parameter = ModelParameter(
test_env["model"], test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
test_env["tokenizer"],
test_env["transformer_config"]
) )
encoder = EmbeddingEncoderCore(parameter) encoder = EmbeddingEncoderCore(parameter)
@ -82,16 +93,14 @@ def test_embedding_encoder_core(test_env):
assert isinstance(single_emb, torch.Tensor) assert isinstance(single_emb, torch.Tensor)
assert single_emb.shape[-1] == test_env["transformer_config"].dim assert single_emb.shape[-1] == test_env["transformer_config"].dim
batch_emb = encoder.encode(["测试1", "测试2"]) batch_emb = encoder.encode(["测试1", "测试2"])
assert isinstance(batch_emb, list) assert isinstance(batch_emb, list)
assert len(batch_emb) == 2 assert len(batch_emb) == 2
def test_generator_core(test_env): def test_generator_core(test_env):
parameter = ModelParameter( parameter = ModelParameter(
test_env["model"], test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
test_env["tokenizer"],
test_env["transformer_config"]
) )
generator = GeneratorCore(parameter) generator = GeneratorCore(parameter)
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)) input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
@ -102,7 +111,7 @@ def test_generator_core(test_env):
top_p=0.95, top_p=0.95,
attn_mask=None, attn_mask=None,
kv_caches=None, kv_caches=None,
start_pos=0 start_pos=0,
) )
assert next_token_id.shape == (4, 1) assert next_token_id.shape == (4, 1)
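Editor's note: the call above exercises `sample_next_token` with `temperature`, `top_p`, and KV-cache arguments, but the sampler itself is not part of this diff. A minimal self-contained sketch of temperature scaling plus nucleus (top-p) sampling, with an assumed function name and shapes, might look like:

```python
import torch

def sample_top_p(logits: torch.Tensor, temperature: float = 0.6, top_p: float = 0.95) -> torch.Tensor:
    # Turn last-position logits into a temperature-scaled distribution.
    probs = torch.softmax(logits / temperature, dim=-1)           # [batch, vocab]
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens outside the smallest set whose mass exceeds top_p;
    # subtracting sorted_probs keeps the first token that crosses the threshold.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)  # [batch, 1]
    return torch.gather(sorted_idx, -1, next_sorted)              # token ids, [batch, 1]

next_token_id = sample_top_p(torch.randn(4, 1000))
assert next_token_id.shape == (4, 1)
```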

View File

@ -22,17 +22,13 @@ def transformer_test_env():
"dim_ffn": 256, "dim_ffn": 256,
"max_len": 64, "max_len": 64,
"n_layers": 2, "n_layers": 2,
"norm_eps": 1e-5 "norm_eps": 1e-5,
} }
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config, f) json.dump(config, f)
yield { yield {"test_dir": test_dir, "config_path": config_path, "config": config}
"test_dir": test_dir,
"config_path": config_path,
"config": config
}
if os.path.exists(test_dir): if os.path.exists(test_dir):
try: try:
@ -50,7 +46,7 @@ def test_tie_weight_init(transformer_test_env):
# case 1: tie weight # case 1: tie weight
config_data["tie_weight"] = True config_data["tie_weight"] = True
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig().load(config_path)
@ -68,7 +64,7 @@ def test_tie_weight_init(transformer_test_env):
# case 2: not tie weight # case 2: not tie weight
config_data["tie_weight"] = False config_data["tie_weight"] = False
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig().load(config_path)
@ -83,6 +79,7 @@ def test_tie_weight_init(transformer_test_env):
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert not torch.equal(model.lm_head.weight, original_weight) assert not torch.equal(model.lm_head.weight, original_weight)
def test_model_save_load_with_tie_weight(transformer_test_env): def test_model_save_load_with_tie_weight(transformer_test_env):
test_dir = transformer_test_env["test_dir"] test_dir = transformer_test_env["test_dir"]
model_path = os.path.join(test_dir, "model.safetensors") model_path = os.path.join(test_dir, "model.safetensors")
@ -93,7 +90,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
config_data["tie_weight"] = True config_data["tie_weight"] = True
config_path = os.path.join(test_dir, "config.json") config_path = os.path.join(test_dir, "config.json")
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig().load(config_path) config = ModelConfig().load(config_path)
@ -111,7 +108,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
# case 2: not tie weight (from tie-weight state dict load) # case 2: not tie weight (from tie-weight state dict load)
config_data["tie_weight"] = False config_data["tie_weight"] = False
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
loaded_config = ModelConfig().load(config_path) loaded_config = ModelConfig().load(config_path)
@ -121,4 +118,3 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
assert "lm_head.weight" in model.state_dict() assert "lm_head.weight" in model.state_dict()

View File

@ -1,16 +1,14 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from khaosz.parallel import ( from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
get_rank,
only_on_rank,
spawn_parallel_fn
)
@only_on_rank(0) @only_on_rank(0)
def _test_only_on_rank_helper(): def _test_only_on_rank_helper():
return True return True
def only_on_rank(): def only_on_rank():
result = _test_only_on_rank_helper() result = _test_only_on_rank_helper()
if get_rank() == 0: if get_rank() == 0:
@ -18,22 +16,17 @@ def only_on_rank():
else: else:
assert result is None assert result is None
def all_reduce(): def all_reduce():
x = torch.tensor([get_rank()], dtype=torch.int) x = torch.tensor([get_rank()], dtype=torch.int)
dist.all_reduce(x, op=dist.ReduceOp.SUM) dist.all_reduce(x, op=dist.ReduceOp.SUM)
expected_sum = sum(range(dist.get_world_size())) expected_sum = sum(range(dist.get_world_size()))
assert x.item() == expected_sum assert x.item() == expected_sum
def test_spawn_only_on_rank(): def test_spawn_only_on_rank():
spawn_parallel_fn( spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo")
only_on_rank,
world_size=2,
backend="gloo"
)
def test_spawn_all_reduce(): def test_spawn_all_reduce():
spawn_parallel_fn( spawn_parallel_fn(all_reduce, world_size=2, backend="gloo")
all_reduce,
world_size=2,
backend="gloo"
)
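Editor's note: `spawn_parallel_fn`'s implementation is not shown in this diff. Judging by how the tests call it, a plausible minimal equivalent would spawn one process per rank, join a `gloo` group, and run the function; the master address, port, and helper names below are assumptions:

```python
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int, backend: str, fn):
    # Every rank must join the same process group before fn runs.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    try:
        fn()
    finally:
        dist.destroy_process_group()

def spawn_parallel_fn_sketch(fn, world_size: int = 2, backend: str = "gloo"):
    # mp.spawn prepends the rank index to args for each child process.
    mp.spawn(_worker, args=(world_size, backend, fn), nprocs=world_size, join=True)
```

Note that `fn` must be defined at module level so it can be pickled into the child processes, which is why the test helpers above are plain top-level functions.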

View File

@ -3,57 +3,48 @@ import torch
from khaosz.config import * from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
def test_callback_integration(base_test_env, random_dataset): def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated""" """Test that all callbacks are properly integrated"""
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"], model=base_test_env["model"],
strategy='seq', strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=2, batch_size=2,
checkpoint_interval=3, ckpt_interval=3,
accumulation_steps=1, accumulation_steps=1,
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=42, random_seed=42,
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
# Create custom callbacks to track calls # Create custom callbacks to track calls
callback_calls = [] callback_calls = []
class TrackingCallback(TrainCallback): class TrackingCallback(TrainCallback):
def on_train_begin(self, context): def on_train_begin(self, context):
callback_calls.append('on_train_begin') callback_calls.append("on_train_begin")
def on_batch_end(self, context): def on_batch_end(self, context):
callback_calls.append('on_batch_end') callback_calls.append("on_batch_end")
def on_epoch_end(self, context): def on_epoch_end(self, context):
callback_calls.append('on_epoch_end') callback_calls.append("on_epoch_end")
trainer = Trainer( trainer = Trainer(train_config, callbacks=[TrackingCallback()])
train_config,
callbacks=[TrackingCallback()]
)
trainer.train() trainer.train()
# Verify callbacks were called # Verify callbacks were called
assert 'on_train_begin' in callback_calls assert "on_train_begin" in callback_calls
assert 'on_batch_end' in callback_calls assert "on_batch_end" in callback_calls
assert 'on_epoch_end' in callback_calls assert "on_epoch_end" in callback_calls
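Editor's note: the assertions above imply a dispatch pattern inside `Trainer.train`: fire each hook on every registered callback at the matching point of the loop. A schematic version (hook names taken from the test, everything else assumed):

```python
from types import SimpleNamespace

class TrainerSketch:
    def __init__(self, config, callbacks=None):
        self.config = config
        self.callbacks = callbacks or []

    def _fire(self, hook: str, context: dict):
        # Invoke the named hook on every registered callback, in order.
        for cb in self.callbacks:
            getattr(cb, hook)(context)

    def train(self):
        context = {"config": self.config}
        self._fire("on_train_begin", context)
        for _ in range(self.config.n_epoch):
            for _ in range(self.config.batches_per_epoch):
                # ... forward / backward / optimizer step would go here ...
                self._fire("on_batch_end", context)
            self._fire("on_epoch_end", context)

# TrainerSketch(SimpleNamespace(n_epoch=1, batches_per_epoch=2)).train()
```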

View File

@ -5,6 +5,7 @@ from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.serialization import Checkpoint from khaosz.data.serialization import Checkpoint
def test_early_stopping_simulation(base_test_env, early_stopping_dataset): def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior""" """Simulate early stopping behavior"""
@ -19,13 +20,13 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
model=base_test_env["model"], model=base_test_env["model"],
dataset=early_stopping_dataset, dataset=early_stopping_dataset,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=2, n_epoch=2,
batch_size=2, batch_size=2,
checkpoint_interval=1, ckpt_interval=1,
accumulation_steps=2, accumulation_steps=2,
random_seed=np.random.randint(1e4), random_seed=np.random.randint(1e4),
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)

View File

@ -20,14 +20,14 @@ def test_schedule_factory_random_configs():
CosineScheduleConfig( CosineScheduleConfig(
warmup_steps=np.random.randint(50, 200), warmup_steps=np.random.randint(50, 200),
total_steps=np.random.randint(1000, 5000), total_steps=np.random.randint(1000, 5000),
min_rate=np.random.uniform(0.01, 0.1) min_rate=np.random.uniform(0.01, 0.1),
), ),
SGDRScheduleConfig( SGDRScheduleConfig(
warmup_steps=np.random.randint(50, 200), warmup_steps=np.random.randint(50, 200),
cycle_length=np.random.randint(500, 2000), cycle_length=np.random.randint(500, 2000),
t_mult=np.random.randint(1, 3), t_mult=np.random.randint(1, 3),
min_rate=np.random.uniform(0.01, 0.1) min_rate=np.random.uniform(0.01, 0.1),
) ),
] ]
for config in schedule_configs: for config in schedule_configs:
@ -41,7 +41,9 @@ def test_schedule_factory_random_configs():
if isinstance(config, CosineScheduleConfig): if isinstance(config, CosineScheduleConfig):
assert isinstance(scheduler, CosineScheduler) assert isinstance(scheduler, CosineScheduler)
assert scheduler.warmup_steps == config.warmup_steps assert scheduler.warmup_steps == config.warmup_steps
assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps assert (
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
)
assert scheduler.min_rate == config.min_rate assert scheduler.min_rate == config.min_rate
elif isinstance(config, SGDRScheduleConfig): elif isinstance(config, SGDRScheduleConfig):
assert isinstance(scheduler, SGDRScheduler) assert isinstance(scheduler, SGDRScheduler)
@ -52,8 +54,8 @@ def test_schedule_factory_random_configs():
# Test scheduler state dict functionality # Test scheduler state dict functionality
state_dict = scheduler.state_dict() state_dict = scheduler.state_dict()
assert 'warmup_steps' in state_dict assert "warmup_steps" in state_dict
assert 'min_rate' in state_dict assert "min_rate" in state_dict
# Test scheduler step functionality # Test scheduler step functionality
initial_lr = scheduler.get_last_lr() initial_lr = scheduler.get_last_lr()
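Editor's note: the relation checked above, `lr_decay_steps == total_steps - warmup_steps`, pins down the usual warmup-plus-cosine shape. As a sketch (the exact curve inside `CosineScheduler` may differ), the multiplier applied to `max_lr` could be:

```python
import math

def cosine_rate(step: int, warmup_steps: int, total_steps: int, min_rate: float = 0.1) -> float:
    if step < warmup_steps:                   # linear warmup from 0 to 1
        return step / max(1, warmup_steps)
    decay_steps = total_steps - warmup_steps  # == scheduler.lr_decay_steps
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps))
    # Cosine decay from 1 down to min_rate over the remaining steps.
    return min_rate + (1.0 - min_rate) * 0.5 * (1.0 + math.cos(math.pi * progress))
```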

View File

@ -6,15 +6,13 @@ from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.dataset import * from khaosz.data.dataset import *
def test_different_batch_sizes(base_test_env, random_dataset): def test_different_batch_sizes(base_test_env, random_dataset):
"""Test training with different batch sizes""" """Test training with different batch sizes"""
batch_sizes = [1, 2, 4, 8] batch_sizes = [1, 2, 4, 8]
for batch_size in batch_sizes: for batch_size in batch_sizes:
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
@ -24,27 +22,25 @@ def test_different_batch_sizes(base_test_env, random_dataset):
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=batch_size, batch_size=batch_size,
checkpoint_interval=5, ckpt_interval=5,
accumulation_steps=1, accumulation_steps=1,
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=np.random.randint(1000), random_seed=np.random.randint(1000),
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
assert train_config.batch_size == batch_size assert train_config.batch_size == batch_size
def test_gradient_accumulation(base_test_env, random_dataset): def test_gradient_accumulation(base_test_env, random_dataset):
"""Test training with different gradient accumulation steps""" """Test training with different gradient accumulation steps"""
accumulation_steps_list = [1, 2, 4] accumulation_steps_list = [1, 2, 4]
for accumulation_steps in accumulation_steps_list: for accumulation_steps in accumulation_steps_list:
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
@ -54,14 +50,14 @@ def test_gradient_accumulation(base_test_env, random_dataset):
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
dataset=random_dataset, dataset=random_dataset,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=2, batch_size=2,
checkpoint_interval=10, ckpt_interval=10,
accumulation_steps=accumulation_steps, accumulation_steps=accumulation_steps,
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=42, random_seed=42,
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
@ -69,20 +65,18 @@ def test_gradient_accumulation(base_test_env, random_dataset):
assert train_config.accumulation_steps == accumulation_steps assert train_config.accumulation_steps == accumulation_steps
def test_memory_efficient_training(base_test_env, random_dataset): def test_memory_efficient_training(base_test_env, random_dataset):
"""Test training with memory-efficient configurations""" """Test training with memory-efficient configurations"""
# Test with smaller batch sizes and gradient checkpointing # Test with smaller batch sizes and gradient checkpointing
small_batch_configs = [ small_batch_configs = [
{"batch_size": 1, "accumulation_steps": 8}, {"batch_size": 1, "accumulation_steps": 8},
{"batch_size": 2, "accumulation_steps": 4}, {"batch_size": 2, "accumulation_steps": 4},
{"batch_size": 4, "accumulation_steps": 2} {"batch_size": 4, "accumulation_steps": 2},
] ]
for config in small_batch_configs: for config in small_batch_configs:
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
warmup_steps=10,
total_steps=20
)
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
@ -92,14 +86,14 @@ def test_memory_efficient_training(base_test_env, random_dataset):
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
n_epoch=1, n_epoch=1,
batch_size=config["batch_size"], batch_size=config["batch_size"],
checkpoint_interval=5, ckpt_interval=5,
accumulation_steps=config["accumulation_steps"], accumulation_steps=config["accumulation_steps"],
max_grad_norm=1.0, max_grad_norm=1.0,
random_seed=42, random_seed=42,
device_type=base_test_env["device"] device_type=base_test_env["device"],
) )
assert train_config.accumulation_steps == config["accumulation_steps"] assert train_config.accumulation_steps == config["accumulation_steps"]
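Editor's note: these configurations all keep `batch_size * accumulation_steps == 8`, that is, a constant effective batch size. The standard accumulation loop divides the loss by `accumulation_steps` so the summed gradients match a single large batch; the `model(**batch)["loss"]` interface below is an assumption, not the repository's exact trainer:

```python
import torch

def train_one_epoch(model, loader, optimizer, accumulation_steps: int, max_grad_norm: float = 1.0):
    optimizer.zero_grad()
    for i, batch in enumerate(loader):
        # Rescale so the accumulated gradient equals a full-batch gradient.
        loss = model(**batch)["loss"] / accumulation_steps
        loss.backward()  # gradients accumulate across micro-batches
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
```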

View File

@ -17,7 +17,7 @@ class GenerationBenchmark:
self, self,
config: ModelConfig, config: ModelConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.float16 dtype: torch.dtype = torch.float16,
): ):
self.config = config self.config = config
self.device = device self.device = device
@ -28,7 +28,13 @@ class GenerationBenchmark:
def _initialize_kv_cache(self, batch_size: int) -> list: def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存""" """初始化KV缓存"""
config = self.config config = self.config
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads) shape = (
batch_size,
config.max_len,
config.n_layers,
config.n_kv_heads,
config.dim // config.n_heads,
)
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache) return (k_cache, v_cache)
@ -39,7 +45,7 @@ class GenerationBenchmark:
high=self.config.vocab_size, high=self.config.vocab_size,
size=(batch_size, prompt_length), size=(batch_size, prompt_length),
device=self.device, device=self.device,
dtype=torch.long dtype=torch.long,
) )
gen_ids = torch.randint( gen_ids = torch.randint(
@ -47,7 +53,7 @@ class GenerationBenchmark:
high=self.config.vocab_size, high=self.config.vocab_size,
size=(batch_size, total_length - prompt_length), size=(batch_size, total_length - prompt_length),
device=self.device, device=self.device,
dtype=torch.long dtype=torch.long,
) )
return prompt_ids, gen_ids return prompt_ids, gen_ids
@ -61,7 +67,9 @@ class GenerationBenchmark:
) -> BenchmarkResult: ) -> BenchmarkResult:
for _ in range(3): for _ in range(3):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length) prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
torch.cuda.synchronize() torch.cuda.synchronize()
@ -70,7 +78,9 @@ class GenerationBenchmark:
total_tokens = batch_size * prompt_length * num_trials total_tokens = batch_size * prompt_length * num_trials
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length) prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
@ -82,8 +92,10 @@ class GenerationBenchmark:
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " print(
f"({prompt_length / trial_time:.1f} tokens/s)") f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult( return BenchmarkResult(
total_tokens=total_tokens, total_tokens=total_tokens,
@ -95,7 +107,7 @@ class GenerationBenchmark:
"prompt_length": prompt_length, "prompt_length": prompt_length,
"dtype": self.dtype, "dtype": self.dtype,
"device": self.device, "device": self.device,
} },
) )
@torch.inference_mode() @torch.inference_mode()
@ -111,8 +123,9 @@ class GenerationBenchmark:
total_tokens = batch_size * gen_length * num_trials total_tokens = batch_size * gen_length * num_trials
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length) prompt_ids, gen_ids = self._prepare_inputs(
batch_size, prompt_length, prompt_length + gen_length
)
kv_cache = self._initialize_kv_cache(batch_size) kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0) _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
@ -125,8 +138,10 @@ class GenerationBenchmark:
current_pos = prompt_length current_pos = prompt_length
for i in range(gen_length): for i in range(gen_length):
input_token = gen_ids[:, i:i+1] input_token = gen_ids[:, i : i + 1]
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos) _ = self.model(
input_token, persistent_key_values=kv_cache, start_pos=current_pos
)
current_pos += 1 current_pos += 1
end_event.record() end_event.record()
@ -135,9 +150,10 @@ class GenerationBenchmark:
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " print(
f"({gen_length / trial_time:.1f} tokens/s)") f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult( return BenchmarkResult(
total_tokens=total_tokens, total_tokens=total_tokens,
@ -150,7 +166,7 @@ class GenerationBenchmark:
"gen_length": gen_length, "gen_length": gen_length,
"dtype": self.dtype, "dtype": self.dtype,
"device": self.device, "device": self.device,
} },
) )
@ -164,9 +180,13 @@ def print_benchmark_result(result: BenchmarkResult):
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s") print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
if benchmark_type == "prefill": if benchmark_type == "prefill":
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}") print(
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding": elif benchmark_type == "decoding":
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}") print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}") print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80) print("-" * 80)
@ -190,9 +210,12 @@ if __name__ == "__main__":
print("Running Transformer Generation Benchmark") print("Running Transformer Generation Benchmark")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5) prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, prompt_length=512, num_trials=5
)
print_benchmark_result(prefill_result) print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5) gen_result = benchmark.run_decoding_benchmark(
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
)
print_benchmark_result(gen_result) print_benchmark_result(gen_result)
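Editor's note: the benchmark's timing idiom (CUDA events recorded around the work, then a synchronize before reading `elapsed_time`) is worth isolating, since host-side timers under-count kernels that are still queued asynchronously. A minimal helper:

```python
import torch

def gpu_seconds(fn, *args, **kwargs) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()            # enqueued on the current CUDA stream
    fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()  # wait until both events have actually fired
    return start.elapsed_time(end) / 1000  # elapsed_time returns milliseconds
```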

View File

@ -21,10 +21,10 @@ def processor(
with disable_random_init(): with disable_random_init():
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
generator = BatchGenerator(param) generator = BatchGenerator(param)
with open(input_json_file, "r", encoding='utf-8') as f: with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_data] queries = [item[question_key] for item in input_data]
@ -41,24 +41,60 @@ def processor(
responses = generator.generate(request) responses = generator.generate(request)
with open(output_json_file, "w", encoding='utf-8') as f: with open(output_json_file, "w", encoding="utf-8") as f:
for query, response in zip(queries, responses): for query, response in zip(queries, responses):
output_item = {question_key: query, response_key: response} output_item = {question_key: query, response_key: response}
f.write(json.dumps(output_item, ensure_ascii=False) + '\n') f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.") parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.") parser.add_argument(
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.") "--model_dir", type=str, required=True, help="Path to the model directory."
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.") )
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.") parser.add_argument(
parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.") "--input_json_file",
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.") type=str,
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.") required=True,
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.") help="Path to the input JSONL file.",
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.") )
parser.add_argument(
"--output_json_file",
type=str,
required=True,
help="Path to the output JSONL file.",
)
parser.add_argument(
"--question_key",
type=str,
default="question",
help="Key for the question in the input JSON.",
)
parser.add_argument(
"--response_key",
type=str,
default="response",
help="Key for the response in the output JSON.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.60,
help="Temperature for generating responses.",
)
parser.add_argument(
"--top_k", type=int, default=30, help="Top-k value for generating responses."
)
parser.add_argument(
"--top_p",
type=float,
default=0.95,
help="Top-p value for generating responses.",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for generating responses."
)
args = parser.parse_args() args = parser.parse_args()

View File

@ -11,59 +11,56 @@ from khaosz.inference.core import disable_random_init
def compute_perplexity( def compute_perplexity(
model: nn.Module, model: nn.Module,
input_ids: Tensor, input_ids: Tensor,
input_mask: Tensor, input_mask: Tensor,
) -> Tensor: ) -> Tensor:
""" """
Compute the perplexity of a batch of input sequences, Compute the perplexity of a batch of input sequences,
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))). where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
""" """
output = model(input_ids, input_mask) output = model(input_ids, input_mask)
logits = output["logits"] logits = output["logits"]
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size] shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1] shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1] shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
loss = F.cross_entropy( loss = F.cross_entropy(
shifted_logits.flatten(0, 1), shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
shifted_input_ids.flatten(0, 1),
reduction='none'
) )
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1] loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
loss = loss * shifted_mask loss = loss * shifted_mask
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1) sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
perplexity = torch.exp(sentence_loss) # [batch_size] perplexity = torch.exp(sentence_loss) # [batch_size]
return perplexity return perplexity
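Editor's note, a quick sanity check of the docstring's formula: if the model assigns probability p to every observed token, the mean NLL is -log p and the perplexity is 1/p.

```python
import math
import torch

p = 0.25
log_probs = torch.full((1, 5), math.log(p))  # 5 next-token predictions, each with prob p
ppl = torch.exp(-log_probs.mean(dim=1))      # PPL = exp(mean NLL) = 1 / p
assert torch.allclose(ppl, torch.tensor([1.0 / p]))  # PPL == 4.0
```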
def process_file( def process_file(
model_dir: str, model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
input_file: str,
output_file: str,
batch_size: int,
text_key: str
): ):
with disable_random_init(): with disable_random_init():
param = ModelParameter.load(model_dir) param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16) param.to(device="cuda", dtype=torch.bfloat16)
model = param.model model = param.model
tokenizer = param.tokenizer tokenizer = param.tokenizer
with open(input_file, "r", encoding='utf-8') as f: with open(input_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]
texts = [item[text_key] for item in input_data] texts = [item[text_key] for item in input_data]
encoded_texts = [tokenizer.encode(text) for text in texts] encoded_texts = [tokenizer.encode(text) for text in texts]
output_data = [] output_data = []
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"): for i in tqdm(
batch_encoded = encoded_texts[i:i + batch_size] range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
batch_texts = texts[i:i + batch_size] ):
batch_encoded = encoded_texts[i : i + batch_size]
batch_texts = texts[i : i + batch_size]
max_len = max(len(seq) for seq in batch_encoded) max_len = max(len(seq) for seq in batch_encoded)
padded_ids = [] padded_ids = []
masks = [] masks = []
@ -82,18 +79,31 @@ def process_file(
for text, ppl in zip(batch_texts, perplexity): for text, ppl in zip(batch_texts, perplexity):
output_data.append({text_key: text, "ppl": float(ppl.item())}) output_data.append({text_key: text, "ppl": float(ppl.item())})
with open(output_file, "w", encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
for item in output_data: for item in output_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n') f.write(json.dumps(item, ensure_ascii=False) + "\n")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.") parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.") parser.add_argument(
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.") "--model_dir", type=str, required=True, help="Path to the model directory."
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.") )
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.") parser.add_argument(
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.") "--input_file", type=str, required=True, help="Path to the input file."
)
parser.add_argument(
"--output_file", type=str, required=True, help="Path to the output file."
)
parser.add_argument(
"--batch_size", type=int, default=4, help="Batch size for evaluation."
)
parser.add_argument(
"--text_key",
type=str,
default="text",
help="Key for the text field in the input data.",
)
args = parser.parse_args() args = parser.parse_args()
with torch.inference_mode(): with torch.inference_mode():

View File

@ -16,39 +16,129 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.") parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.") parser.add_argument(
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.") "--train_type",
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.") type=str,
required=True,
choices=["seq", "sft", "dpo"],
help="Train type.",
)
parser.add_argument(
"--data_root_path",
type=str,
required=True,
help="Path to the root directory of the dataset.",
)
parser.add_argument(
"--param_path",
type=str,
required=True,
help="Path to the model parameters or resume checkpoint.",
)
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.") parser.add_argument(
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.") "--n_epoch", type=int, default=1, help="Number of epochs to train."
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.") )
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.") parser.add_argument(
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.") "--batch_size", type=int, default=1, help="Batch size for training."
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.") )
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.") parser.add_argument(
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.") "--accumulation_steps",
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.") type=int,
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") default=1,
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.") help="Number of iterations between each optimizer step.",
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory") )
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.") parser.add_argument(
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.") "--warmup_steps",
type=int,
default=1000,
help="Number of iters between warnings.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=1.0,
help="Max gradient norm for clipping.",
)
parser.add_argument(
"--adamw_beta1",
type=float,
default=0.9,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.95,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_weight_decay",
type=float,
default=0.01,
help="Weight decay for AdamW optimizer.",
)
parser.add_argument(
"--random_seed", type=int, default=3407, help="Random seed for reproducibility."
)
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of workers for data loading."
)
parser.add_argument(
"--no_pin_memory",
action="store_false",
dest="pin_memory",
help="Disable pin memory",
)
parser.add_argument(
"--window_size",
type=int,
default=None,
help="the max length of the input sequence.",
)
parser.add_argument(
"--stride", type=int, default=None, help="the step size of the input sequence."
)
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter") parser.add_argument(
"--label_smoothing",
type=float,
default=0.1,
help="cross_entropy function label smoothing parameter",
)
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.") parser.add_argument(
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") "--ckpt_interval",
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") type=int,
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") default=5000,
help="Number of iters between checkpoints.",
)
parser.add_argument(
"--ckpt_dir",
type=str,
default="checkpoint",
help="Directory to save checkpoints.",
)
parser.add_argument(
"--start_epoch", type=int, default=0, help="Start epoch for training."
)
parser.add_argument(
"--start_batch", type=int, default=0, help="Start batch for training."
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.") parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
args = parser.parse_args() args = parser.parse_args()
return args return args
def ddp_wrap(model: nn.Module): def ddp_wrap(model: nn.Module):
local_rank = get_rank() local_rank = get_rank()
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16) model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
@ -56,16 +146,21 @@ def ddp_wrap(model: nn.Module):
model, model,
device_ids=[local_rank], device_ids=[local_rank],
output_device=local_rank, output_device=local_rank,
find_unused_parameters=False find_unused_parameters=False,
) )
return ddp_model return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), **kwargs) return optim.AdamW(model.parameters(), **kwargs)
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler: def create_scheduler(
optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.load(optimizer, **kwargs) return SchedulerFactory.load(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict: def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict() return model.module.state_dict()
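Editor's note on the `optimizer_fn` / `scheduler_fn` indirection used below: the config carries factories rather than instances, so the optimizer can be constructed after the model has been moved to its device and wrapped in DDP. A self-contained illustration (hyperparameter values are arbitrary):

```python
from functools import partial
import torch.nn as nn
import torch.optim as optim

def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
    return optim.AdamW(model.parameters(), **kwargs)

# The factory is configured now, but the optimizer is built only at train
# time, once the (possibly DDP-wrapped) model actually exists.
optimizer_fn = partial(create_optimizer, lr=2e-4, betas=(0.9, 0.95), weight_decay=0.01)
model = nn.Linear(8, 8)
optimizer = optimizer_fn(model)
```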
@ -81,8 +176,8 @@ def train(
start_batch: int, start_batch: int,
accumulation_steps: int, accumulation_steps: int,
warmup_steps: int, warmup_steps: int,
checkpoint_interval: int, ckpt_interval: int,
checkpoint_dir: str, ckpt_dir: str,
dpo_beta: float, dpo_beta: float,
adamw_beta1: float, adamw_beta1: float,
adamw_beta2: float, adamw_beta2: float,
@ -107,16 +202,13 @@ def train(
model = parameter.model model = parameter.model
strategy_kwargs = { strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
"dpo_beta": dpo_beta,
"label_smoothing": label_smoothing
}
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
window_size=window_size, window_size=window_size,
stride=stride stride=stride,
) )
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(
@ -124,10 +216,15 @@ def train(
total_steps=len(dataset) * n_epoch // (batch_size * nprocs), total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
) )
optimizer_fn = partial(create_optimizer, optimizer_fn = partial(
**{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay}) create_optimizer,
scheduler_fn = partial(create_scheduler, **{
**{"schedule_config": schedule_config}) "lr": max_lr,
"betas": (adamw_beta1, adamw_beta2),
"weight_decay": adamw_weight_decay,
},
)
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
train_config = TrainConfig( train_config = TrainConfig(
model=model, model=model,
@ -135,12 +232,12 @@ def train(
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
checkpoint_dir=checkpoint_dir, ckpt_dir=ckpt_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_size=batch_size, batch_size=batch_size,
start_epoch=start_epoch, start_epoch=start_epoch,
start_batch=start_batch, start_batch=start_batch,
checkpoint_interval=checkpoint_interval, ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps, accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,