style: 使用ruff 工具优化代码风格
This commit is contained in:
parent
345fd2f091
commit
426af2d75f
|
|
@ -0,0 +1,27 @@
|
|||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install .[dev]
|
||||
|
||||
- name: Check formatting with ruff
|
||||
run: |
|
||||
ruff format --check .
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
name: Spell Check
|
||||
on: [push, pull_request]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
spellcheck:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Check spelling in specific files
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
check_filenames: true
|
||||
only_warn: false
|
||||
path: "**/*.{md, py}"
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install .[dev]
|
||||
|
||||
- name: Run tests with pytest
|
||||
run: |
|
||||
python -m pytest tests/ -v
|
||||
|
|
@ -8,5 +8,8 @@
|
|||
!*.py
|
||||
!*.md
|
||||
!*.png
|
||||
|
||||
!LICENSE
|
||||
!pyproject.toml
|
||||
!pyproject.toml
|
||||
!.github/workflows/lint.yml
|
||||
!.github/workflows/tests.yml
|
||||
16
README.md
16
README.md
|
|
@ -54,8 +54,8 @@ python train.py \
|
|||
--n_epoch=5 \
|
||||
--batch_size=8 \
|
||||
--max_lr=2e-4 \
|
||||
--checkpoint_interval=10000 \
|
||||
--checkpoint_dir=checkpoints
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=checkpoints
|
||||
```
|
||||
|
||||
**Parameter Explanation:**
|
||||
|
|
@ -67,8 +67,8 @@ python train.py \
|
|||
- `--accumulation_steps`: Number of batches per training step
|
||||
- `--warmup_steps`: Warmup steps
|
||||
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
||||
- `--checkpoint_interval`: Checkpoint saving interval
|
||||
- `--checkpoint_dir`: Checkpoint saving directory
|
||||
- `--ckpt_interval`: Checkpoint saving interval
|
||||
- `--ckpt_dir`: Checkpoint saving directory
|
||||
- `--resume_dir`: Resume training from specified path
|
||||
|
||||
|
||||
|
|
@ -191,8 +191,8 @@ python train.py \
|
|||
--n_epoch=5 \
|
||||
--batch_size=8 \
|
||||
--max_lr=2e-4 \
|
||||
--checkpoint_interval=10000 \
|
||||
--checkpoint_dir=checkpoints
|
||||
--ckpt_interval=10000 \
|
||||
--ckpt_dir=checkpoints
|
||||
```
|
||||
|
||||
**参数说明:**
|
||||
|
|
@ -204,8 +204,8 @@ python train.py \
|
|||
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
||||
- `--warmup_steps`: 预热步数(warmup steps)
|
||||
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
||||
- `--checkpoint_interval`: 检查点保存间隔
|
||||
- `--checkpoint_dir`: 检查点保存目录
|
||||
- `--ckpt_interval`: 检查点保存间隔
|
||||
- `--ckpt_dir`: 检查点保存目录
|
||||
- `--resume_dir`: 从指定路径恢复训练
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,13 +2,12 @@ import os
|
|||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
snapshot_download(
|
||||
repo_id="ViperEk/KHAOSZ",
|
||||
local_dir=os.path.join(PROJECT_ROOT, "params"),
|
||||
force_download=True
|
||||
)
|
||||
local_dir=os.path.join(PROJECT_ROOT, "params"),
|
||||
force_download=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,18 +5,18 @@ from khaosz.inference.core import disable_random_init
|
|||
from khaosz.inference.generator import LoopGenerator, GenerationRequest
|
||||
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def generate_text():
|
||||
|
||||
|
||||
with disable_random_init():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
param.to(device="cuda", dtype=torch.bfloat16)
|
||||
query = input(">> ")
|
||||
|
||||
|
||||
request = GenerationRequest(
|
||||
query=query,
|
||||
temperature=0.8,
|
||||
|
|
@ -28,8 +28,9 @@ def generate_text():
|
|||
)
|
||||
generator = LoopGenerator(param)
|
||||
response = generator.generate(request)
|
||||
|
||||
|
||||
print(response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_text()
|
||||
generate_text()
|
||||
|
|
|
|||
|
|
@ -4,18 +4,24 @@ from khaosz.config.param_config import ModelParameter
|
|||
from khaosz.inference.core import disable_random_init
|
||||
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def batch_generate():
|
||||
with disable_random_init():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
param.to(device="cuda", dtype=torch.bfloat16)
|
||||
generator = BatchGenerator(param)
|
||||
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
|
||||
|
||||
inputs = [
|
||||
"你好",
|
||||
"请问什么是人工智能",
|
||||
"今天天气如何",
|
||||
"我感到焦虑, 请问我应该怎么办",
|
||||
"请问什么是显卡",
|
||||
]
|
||||
|
||||
request = GenerationRequest(
|
||||
query=inputs,
|
||||
temperature=0.8,
|
||||
|
|
@ -26,9 +32,10 @@ def batch_generate():
|
|||
system_prompt=None,
|
||||
)
|
||||
responses = generator.generate(request)
|
||||
|
||||
|
||||
for q, r in zip(inputs, responses):
|
||||
print((q, r))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_generate()
|
||||
batch_generate()
|
||||
|
|
|
|||
|
|
@ -5,16 +5,16 @@ from khaosz.inference.core import disable_random_init
|
|||
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
||||
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def chat():
|
||||
|
||||
|
||||
with disable_random_init():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
param.to(device="cuda", dtype=torch.bfloat16)
|
||||
generator = StreamGenerator(param)
|
||||
|
||||
history = []
|
||||
|
|
@ -22,7 +22,7 @@ def chat():
|
|||
query = input(">> ")
|
||||
if query == "!exit":
|
||||
break
|
||||
|
||||
|
||||
request = GenerationRequest(
|
||||
query=query,
|
||||
temperature=0.8,
|
||||
|
|
@ -32,7 +32,7 @@ def chat():
|
|||
history=history,
|
||||
system_prompt=None,
|
||||
)
|
||||
|
||||
|
||||
response_size = 0
|
||||
full_response = ""
|
||||
for response in generator.generate(request):
|
||||
|
|
@ -40,10 +40,10 @@ def chat():
|
|||
print(response[response_size:], end="", flush=True)
|
||||
response_size = len(response)
|
||||
full_response = response
|
||||
|
||||
|
||||
# After generation, update history
|
||||
history.append((query, full_response.strip()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
chat()
|
||||
chat()
|
||||
|
|
|
|||
|
|
@ -6,41 +6,30 @@ from khaosz.config import (
|
|||
TrainConfig,
|
||||
)
|
||||
from khaosz.model.transformer import Transformer
|
||||
from khaosz.data import (
|
||||
DatasetLoader,
|
||||
BpeTokenizer
|
||||
)
|
||||
from khaosz.data import DatasetLoader, BpeTokenizer
|
||||
from khaosz.inference.generator import (
|
||||
GenerationRequest,
|
||||
LoopGenerator,
|
||||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
EmbeddingEncoder,
|
||||
GeneratorFactory
|
||||
)
|
||||
from khaosz.trainer import (
|
||||
Trainer,
|
||||
StrategyFactory,
|
||||
SchedulerFactory
|
||||
GeneratorFactory,
|
||||
)
|
||||
from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
|
||||
|
||||
__all__ = [
|
||||
"Transformer",
|
||||
|
||||
"ModelConfig",
|
||||
"TrainConfig",
|
||||
|
||||
"DatasetLoader",
|
||||
"BpeTokenizer",
|
||||
|
||||
"GenerationRequest",
|
||||
"LoopGenerator",
|
||||
"StreamGenerator",
|
||||
"BatchGenerator",
|
||||
"EmbeddingEncoder",
|
||||
"GeneratorFactory",
|
||||
|
||||
"Trainer",
|
||||
"StrategyFactory",
|
||||
"SchedulerFactory"
|
||||
]
|
||||
"SchedulerFactory",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
||||
from khaosz.config.schedule_config import (
|
||||
ScheduleConfig,
|
||||
CosineScheduleConfig,
|
||||
ScheduleConfig,
|
||||
CosineScheduleConfig,
|
||||
SGDRScheduleConfig,
|
||||
ScheduleConfigFactory
|
||||
ScheduleConfigFactory,
|
||||
)
|
||||
from khaosz.config.train_config import TrainConfig
|
||||
|
||||
|
|
@ -13,14 +13,12 @@ __all__ = [
|
|||
# Base I/O
|
||||
"BaseModelIO",
|
||||
"ModelParameter",
|
||||
|
||||
# Model configuration
|
||||
"ModelConfig",
|
||||
"TrainConfig",
|
||||
|
||||
# Schedule configuration
|
||||
"ScheduleConfig",
|
||||
"CosineScheduleConfig",
|
||||
"SGDRScheduleConfig",
|
||||
"ScheduleConfigFactory",
|
||||
]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -14,30 +14,29 @@ class ModelConfig:
|
|||
norm_eps: Optional[float] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
tie_weight: Optional[bool] = None
|
||||
|
||||
|
||||
# RoPE
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
|
||||
# GQA
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
config = {}
|
||||
with open(config_path, 'r') as f:
|
||||
config.update(json.load(f))
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config.update(json.load(f))
|
||||
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def save(self, config_path: str):
|
||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||
with open(config_path, 'w') as f:
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
|
|
|||
|
|
@ -9,58 +9,57 @@ from khaosz.data.tokenizer import BpeTokenizer
|
|||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.model.transformer import Transformer
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelIO:
|
||||
"""Base class for model I/O operations."""
|
||||
|
||||
|
||||
model: Optional[nn.Module] = field(
|
||||
default=None,
|
||||
metadata={"help": "Transformer model."}
|
||||
default=None, metadata={"help": "Transformer model."}
|
||||
)
|
||||
tokenizer: BpeTokenizer = field(
|
||||
default_factory=BpeTokenizer,
|
||||
metadata={"help": "Tokenizer for the model."}
|
||||
default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
|
||||
)
|
||||
config: ModelConfig = field(
|
||||
default_factory=ModelConfig,
|
||||
metadata={"help": "Transformer model configuration."}
|
||||
metadata={"help": "Transformer model configuration."},
|
||||
)
|
||||
|
||||
|
||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||
"""Get standardized file paths for model components."""
|
||||
dir_path = Path(directory)
|
||||
return {
|
||||
"model": dir_path / "model.safetensors",
|
||||
"config": dir_path / "config.json",
|
||||
"tokenizer": dir_path / "tokenizer.json"
|
||||
"config": dir_path / "config.json",
|
||||
"tokenizer": dir_path / "tokenizer.json",
|
||||
}
|
||||
|
||||
|
||||
def save_components(self, save_dir: Union[str, Path]):
|
||||
"""Save core model components."""
|
||||
paths = self._get_file_paths(save_dir)
|
||||
paths["model"].parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
if self.model is not None:
|
||||
st.save_file(self.model.state_dict(), str(paths["model"]))
|
||||
self.config.save(str(paths["config"]))
|
||||
self.tokenizer.save(str(paths["tokenizer"]))
|
||||
|
||||
|
||||
def load_components(self, load_dir: Union[str, Path]) -> Self:
|
||||
"""Load core model components."""
|
||||
paths = self._get_file_paths(load_dir)
|
||||
|
||||
|
||||
self.config.load(str(paths["config"]))
|
||||
self.tokenizer.load(str(paths["tokenizer"]))
|
||||
|
||||
|
||||
if self.model is None:
|
||||
self.model = Transformer(self.config)
|
||||
|
||||
|
||||
if paths["model"].exists():
|
||||
state_dict = st.load_file(str(paths["model"]))
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def to(self, *args, **kwargs) -> "BaseModelIO":
|
||||
"""Move model to device."""
|
||||
if self.model is not None:
|
||||
|
|
@ -71,13 +70,12 @@ class BaseModelIO:
|
|||
@dataclass
|
||||
class ModelParameter(BaseModelIO):
|
||||
"""Container for model parameters with serialization capabilities."""
|
||||
|
||||
|
||||
@classmethod
|
||||
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
|
||||
instance.save_components(save_dir)
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||
instance = cls()
|
||||
return instance.load_components(load_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,35 +6,35 @@ from dataclasses import dataclass, field
|
|||
@dataclass
|
||||
class ScheduleConfig(ABC):
|
||||
"""Base configuration class for learning rate schedulers.
|
||||
|
||||
|
||||
Provides common validation and interface for all schedule types.
|
||||
"""
|
||||
|
||||
|
||||
schedule_type: str = field(
|
||||
default="cosine",
|
||||
metadata={
|
||||
"help": "Type of learning rate schedule.",
|
||||
"choices": ["cosine", "sgdr"]
|
||||
}
|
||||
"help": "Type of learning rate schedule.",
|
||||
"choices": ["cosine", "sgdr"],
|
||||
},
|
||||
)
|
||||
warmup_steps: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Number of warmup steps."}
|
||||
default=1000, metadata={"help": "Number of warmup steps."}
|
||||
)
|
||||
min_rate: float = field(
|
||||
default=0.05,
|
||||
metadata={"help": "Minimum learning rate multiplier."}
|
||||
default=0.05, metadata={"help": "Minimum learning rate multiplier."}
|
||||
)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get configuration kwargs for scheduler creation."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
if self.warmup_steps < 0:
|
||||
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
|
||||
raise ValueError(
|
||||
f"warmup_steps must be non-negative, got {self.warmup_steps}"
|
||||
)
|
||||
if not 0 <= self.min_rate <= 1:
|
||||
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
||||
|
||||
|
|
@ -42,44 +42,43 @@ class ScheduleConfig(ABC):
|
|||
@dataclass
|
||||
class CosineScheduleConfig(ScheduleConfig):
|
||||
"""Cosine annealing learning rate schedule configuration."""
|
||||
|
||||
|
||||
total_steps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Total training steps for cosine schedule."}
|
||||
default=None, metadata={"help": "Total training steps for cosine schedule."}
|
||||
)
|
||||
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.schedule_type = "cosine"
|
||||
self.validate()
|
||||
|
||||
|
||||
def get_kwargs(self) -> Dict[str, Any]:
|
||||
if self.total_steps is None:
|
||||
raise ValueError("total_steps must be specified for cosine schedule")
|
||||
|
||||
|
||||
return {
|
||||
"schedule_type": self.schedule_type,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
||||
"min_rate": self.min_rate
|
||||
"min_rate": self.min_rate,
|
||||
}
|
||||
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
||||
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
||||
raise ValueError(
|
||||
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SGDRScheduleConfig(ScheduleConfig):
|
||||
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
|
||||
|
||||
|
||||
cycle_length: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Length of the first cycle in steps."}
|
||||
default=1000, metadata={"help": "Length of the first cycle in steps."}
|
||||
)
|
||||
t_mult: int = field(
|
||||
default=2,
|
||||
metadata={"help": "Multiplier for cycle length growth."}
|
||||
t_mult: int = field(
|
||||
default=2, metadata={"help": "Multiplier for cycle length growth."}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
|
@ -92,9 +91,9 @@ class SGDRScheduleConfig(ScheduleConfig):
|
|||
"warmup_steps": self.warmup_steps,
|
||||
"cycle_length": self.cycle_length,
|
||||
"min_rate": self.min_rate,
|
||||
"t_mult": self.t_mult
|
||||
"t_mult": self.t_mult,
|
||||
}
|
||||
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
if self.cycle_length <= 0:
|
||||
|
|
@ -105,33 +104,33 @@ class SGDRScheduleConfig(ScheduleConfig):
|
|||
|
||||
class ScheduleConfigFactory:
|
||||
"""Factory class for creating ScheduleConfig instances.
|
||||
|
||||
|
||||
Supports both direct instantiation and factory creation methods.
|
||||
|
||||
|
||||
Example usage:
|
||||
# Direct creation
|
||||
config = CosineScheduleConfig(total_steps=10000)
|
||||
|
||||
|
||||
# Factory method
|
||||
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
|
||||
"""
|
||||
|
||||
|
||||
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
|
||||
"cosine": CosineScheduleConfig,
|
||||
"sgdr": SGDRScheduleConfig,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
|
||||
"""Create a schedule config instance.
|
||||
|
||||
|
||||
Args:
|
||||
schedule_type: Type of schedule ("cosine", "sgdr")
|
||||
**kwargs: Arguments passed to the config constructor
|
||||
|
||||
|
||||
Returns:
|
||||
ScheduleConfig instance
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If schedule_type is not supported
|
||||
"""
|
||||
|
|
@ -140,11 +139,11 @@ class ScheduleConfigFactory:
|
|||
f"Unknown schedule type: '{schedule_type}'. "
|
||||
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
|
||||
)
|
||||
|
||||
|
||||
config_cls = cls.CONFIG_MAP[schedule_type]
|
||||
return config_cls(**kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of available schedule type names."""
|
||||
return list(cls.CONFIG_MAP.keys())
|
||||
return list(cls.CONFIG_MAP.keys())
|
||||
|
|
|
|||
|
|
@ -10,127 +10,92 @@ from typing import Callable, List, Optional
|
|||
@dataclass
|
||||
class TrainConfig:
|
||||
# basic setting
|
||||
model: nn.Module = field(
|
||||
default=None,
|
||||
metadata={"help": "Model for training."}
|
||||
)
|
||||
strategy: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Training strategy."}
|
||||
)
|
||||
dataset: Dataset = field(
|
||||
default=None,
|
||||
metadata={"help": "Dataset for training."}
|
||||
)
|
||||
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
|
||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer factory for training."}
|
||||
default=None, metadata={"help": "Optimizer factory for training."}
|
||||
)
|
||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scheduler factory for training."}
|
||||
)
|
||||
n_epoch: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of epochs for training."}
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Batch size for training."}
|
||||
default=None, metadata={"help": "Scheduler factory for training."}
|
||||
)
|
||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
||||
accumulation_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of iterations between steps."}
|
||||
default=1, metadata={"help": "Number of iterations between steps."}
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum gradient norm."}
|
||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
|
||||
|
||||
# checkpoint setting
|
||||
start_epoch: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Start epoch for training."}
|
||||
)
|
||||
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||
start_batch: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Start batch iteration for training."}
|
||||
default=0, metadata={"help": "Start batch iteration for training."}
|
||||
)
|
||||
checkpoint_dir: str = field(
|
||||
default="./checkpoint",
|
||||
metadata={"help": "Checkpoint directory."}
|
||||
ckpt_dir: str = field(
|
||||
default="./checkpoint", metadata={"help": "Checkpoint directory."}
|
||||
)
|
||||
checkpoint_interval: int = field(
|
||||
default=5000,
|
||||
metadata={"help": "Number of iterations between checkpoints."}
|
||||
ckpt_interval: int = field(
|
||||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
|
||||
|
||||
# dataloader setting
|
||||
random_seed: int = field(
|
||||
default=3407,
|
||||
metadata={"help": "Random seed."}
|
||||
)
|
||||
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||
num_workers: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Number of workers for dataloader."}
|
||||
default=0, metadata={"help": "Number of workers for dataloader."}
|
||||
)
|
||||
prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Prefetch factor for dataloader."}
|
||||
default=None, metadata={"help": "Prefetch factor for dataloader."}
|
||||
)
|
||||
pin_memory: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Pin memory for dataloader."}
|
||||
default=False, metadata={"help": "Pin memory for dataloader."}
|
||||
)
|
||||
|
||||
|
||||
# distributed training
|
||||
nprocs: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of processes for distributed training."}
|
||||
default=1, metadata={"help": "Number of processes for distributed training."}
|
||||
)
|
||||
backend: str = field(
|
||||
default="nccl",
|
||||
metadata={"help": "Distributed training backend."}
|
||||
default="nccl", metadata={"help": "Distributed training backend."}
|
||||
)
|
||||
master_addr: str = field(
|
||||
default="localhost",
|
||||
metadata={"help": "Master address for distributed training."}
|
||||
metadata={"help": "Master address for distributed training."},
|
||||
)
|
||||
master_port: str = field(
|
||||
default="29500",
|
||||
metadata={"help": "Master port for distributed training."}
|
||||
default="29500", metadata={"help": "Master port for distributed training."}
|
||||
)
|
||||
parallel_wrapper: Optional[Callable] = field(
|
||||
default=None,
|
||||
metadata={"help": "Parallel function for training."}
|
||||
default=None, metadata={"help": "Parallel function for training."}
|
||||
)
|
||||
state_dict_fn: Optional[Callable] = field(
|
||||
default=None,
|
||||
metadata={"help": "Parallel function for state dict saving."}
|
||||
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||
)
|
||||
|
||||
# others
|
||||
device_ids: Optional[List[int]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Device ids for distributed training."}
|
||||
default=None, metadata={"help": "Device ids for distributed training."}
|
||||
)
|
||||
device_type: str = field(
|
||||
default="cuda",
|
||||
metadata={"help": "Device type for distributed training."}
|
||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||
)
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Other arguments."}
|
||||
default_factory=dict, metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
|
||||
|
||||
def validate(self):
|
||||
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
|
||||
|
||||
required_fields = [
|
||||
"model",
|
||||
"strategy",
|
||||
"dataset",
|
||||
"optimizer_fn",
|
||||
"scheduler_fn",
|
||||
]
|
||||
|
||||
for field_name in required_fields:
|
||||
if getattr(self, field_name) is None:
|
||||
raise ValueError(f"{field_name} is required.")
|
||||
|
||||
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
from khaosz.data.dataset import (
|
||||
BaseDataset,
|
||||
SEQDataset,
|
||||
DPODataset,
|
||||
SFTDataset,
|
||||
BaseDataset,
|
||||
SEQDataset,
|
||||
DPODataset,
|
||||
SFTDataset,
|
||||
GRPODataset,
|
||||
MultiSegmentFetcher,
|
||||
DatasetLoader,
|
||||
DatasetFactory
|
||||
DatasetFactory,
|
||||
)
|
||||
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
|
|
@ -15,21 +15,17 @@ from khaosz.data.sampler import ResumableDistributedSampler
|
|||
__all__ = [
|
||||
# Base classes
|
||||
"BaseDataset",
|
||||
|
||||
# Dataset implementations
|
||||
"SEQDataset",
|
||||
"SFTDataset",
|
||||
"DPODataset",
|
||||
"GRPODataset",
|
||||
|
||||
# Fetchers
|
||||
"MultiSegmentFetcher",
|
||||
|
||||
# Factory (DatasetLoader is alias for backward compatibility)
|
||||
"DatasetLoader",
|
||||
"DatasetFactory",
|
||||
|
||||
# Tokenizer and sampler
|
||||
"BpeTokenizer",
|
||||
"ResumableDistributedSampler"
|
||||
]
|
||||
"ResumableDistributedSampler",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -12,40 +12,42 @@ from typing import Callable, List, Dict, Literal, Optional, Union
|
|||
|
||||
class BaseSegmentFetcher:
|
||||
"""Fetches data segments across multiple tensor segments.
|
||||
|
||||
|
||||
Maintains cumulative lengths for efficient range queries across
|
||||
multiple discontinuous segments.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
|
||||
|
||||
total = 0
|
||||
for seg in segments:
|
||||
total += torch.numel(seg)
|
||||
self.cum_lengths.append(total)
|
||||
|
||||
|
||||
self.total_length = total
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.total_length
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
"""Fetch data in the range [begin_idx, end_idx).
|
||||
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index (inclusive)
|
||||
end_idx: Ending index (exclusive)
|
||||
|
||||
|
||||
Returns:
|
||||
Concatenated tensor of data in the specified range
|
||||
"""
|
||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
||||
if not (
|
||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||
):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
|
||||
# Find segment boundaries for the range
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||
|
|
@ -64,43 +66,44 @@ class BaseSegmentFetcher:
|
|||
|
||||
class MultiSegmentFetcher:
|
||||
"""Manages multiple segment fetchers for different data keys.
|
||||
|
||||
|
||||
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, muti_segments: Dict):
|
||||
self.muti_keys = list(muti_segments.keys())
|
||||
self.muti_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in muti_segments.items()
|
||||
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
|
||||
}
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
||||
|
||||
def key_fetch(
|
||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||
) -> Dict:
|
||||
"""Fetch data for specific keys.
|
||||
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index
|
||||
end_idx: Ending index
|
||||
keys: Single key or list of keys to fetch
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary of tensors if multiple keys, single tensor if one key
|
||||
"""
|
||||
fetch_dict = {}
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
|
||||
for key in keys:
|
||||
fetcher = self.muti_fetchers[key]
|
||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||
fetch_dict[key] = fetch_tensor
|
||||
|
||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
"""Fetch all keys."""
|
||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||
|
|
@ -108,10 +111,10 @@ class MultiSegmentFetcher:
|
|||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
"""Abstract base class for all dataset types.
|
||||
|
||||
|
||||
Implements common functionality for window-based data fetching.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__()
|
||||
self.segments = {}
|
||||
|
|
@ -122,38 +125,38 @@ class BaseDataset(Dataset, ABC):
|
|||
|
||||
def load(self, load_path: str):
|
||||
"""Load dataset from HDF5 file.
|
||||
|
||||
|
||||
Args:
|
||||
load_path: Path to the HDF5 data file
|
||||
"""
|
||||
self.segments = load_h5(load_path)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
self.total_samples = len(self.fetcher)
|
||||
|
||||
|
||||
def get_index(self, index: int) -> tuple:
|
||||
"""Calculate begin and end indices for a sample.
|
||||
|
||||
|
||||
Args:
|
||||
index: Sample index
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (begin_idx, end_idx)
|
||||
"""
|
||||
assert self.total_samples > self.window_size
|
||||
|
||||
|
||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
||||
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
||||
|
||||
|
||||
return begin_idx, end_idx
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
"""Get a single sample by index.
|
||||
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
assert self.total_samples is not None
|
||||
if self.total_samples <= self.window_size:
|
||||
|
|
@ -163,48 +166,50 @@ class BaseDataset(Dataset, ABC):
|
|||
|
||||
class DatasetFactory:
|
||||
"""Factory class for creating dataset instances.
|
||||
|
||||
|
||||
Supports decorator-based registration for extensible dataset types.
|
||||
All default dataset types (seq, sft, dpo, grpo) are registered automatically
|
||||
when their classes are defined with the decorator.
|
||||
|
||||
|
||||
Example usage:
|
||||
@DatasetFactory.register("custom")
|
||||
class CustomDataset(BaseDataset):
|
||||
...
|
||||
|
||||
|
||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||
"""
|
||||
|
||||
|
||||
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||
DATASET_MAP: Dict[str, type] = {}
|
||||
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new dataset class.
|
||||
|
||||
|
||||
Args:
|
||||
name: Registration name for the dataset type
|
||||
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the dataset class
|
||||
"""
|
||||
|
||||
def decorator(dataset_cls: type) -> type:
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
cls.DATASET_MAP[name] = dataset_cls
|
||||
return dataset_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
|
||||
"""Create a dataset instance.
|
||||
|
||||
|
||||
Args:
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples
|
||||
|
||||
|
||||
Returns:
|
||||
Dataset instance
|
||||
"""
|
||||
|
|
@ -213,36 +218,42 @@ class DatasetFactory:
|
|||
f"Unknown dataset type: '{train_type}'. "
|
||||
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
|
||||
)
|
||||
|
||||
|
||||
if train_type not in cls.DATASET_MAP:
|
||||
raise NotImplementedError(
|
||||
f"Dataset type '{train_type}' is supported but not yet implemented."
|
||||
)
|
||||
|
||||
|
||||
dataset_cls = cls.DATASET_MAP[train_type]
|
||||
return dataset_cls(window_size, stride)
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
|
||||
def load(
|
||||
cls,
|
||||
train_type: str,
|
||||
load_path: str,
|
||||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
) -> BaseDataset:
|
||||
"""Create and load a dataset in one step.
|
||||
|
||||
|
||||
Args:
|
||||
train_type: Type of training dataset
|
||||
load_path: Path to the data file
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples (default: same as window_size)
|
||||
|
||||
|
||||
Returns:
|
||||
Loaded dataset instance
|
||||
"""
|
||||
if stride is None:
|
||||
stride = window_size
|
||||
|
||||
|
||||
dataset = cls.create(train_type, window_size, stride)
|
||||
dataset.load(load_path)
|
||||
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered dataset type names."""
|
||||
|
|
@ -256,46 +267,50 @@ class DatasetFactory:
|
|||
@DatasetFactory.register("seq")
|
||||
class SEQDataset(BaseDataset):
|
||||
"""Dataset for sequential next-token prediction training."""
|
||||
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||
|
||||
|
||||
return {"input_ids": x, "target_ids": y}
|
||||
|
||||
|
||||
@DatasetFactory.register("sft")
|
||||
class SFTDataset(BaseDataset):
|
||||
"""Dataset for supervised fine-tuning with loss masking."""
|
||||
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
|
||||
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||
|
||||
|
||||
@DatasetFactory.register("dpo")
|
||||
class DPODataset(BaseDataset):
|
||||
"""Dataset for Direct Preference Optimization training."""
|
||||
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
|
|
@ -304,25 +319,34 @@ class DPODataset(BaseDataset):
|
|||
|
||||
def __getitem__(self, index: int):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
|
||||
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
|
||||
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
|
||||
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||
return {
|
||||
"chosen": chosen,
|
||||
"rejected": rejected,
|
||||
"chosen_mask": chosen_mask,
|
||||
"rejected_mask": rejected_mask,
|
||||
}
|
||||
|
||||
|
||||
@DatasetFactory.register("grpo")
|
||||
class GRPODataset(BaseDataset):
|
||||
"""Dataset for Group Relative Policy Optimization training."""
|
||||
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
|
|
@ -330,8 +354,13 @@ class GRPODataset(BaseDataset):
|
|||
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
||||
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||
|
||||
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
||||
|
||||
return {
|
||||
"prompts": prompts,
|
||||
"responses": responses,
|
||||
"masks": masks,
|
||||
"rewards": rewards,
|
||||
}
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
|
|
|
|||
|
|
@ -7,45 +7,45 @@ from typing import Optional
|
|||
|
||||
class ResumableDistributedSampler(Sampler[int]):
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
data_source: Dataset,
|
||||
start_epoch: int=0,
|
||||
start_iter: int=0,
|
||||
seed: int=42,
|
||||
drop_last: bool=False,
|
||||
shuffle: bool=True,
|
||||
process_group: Optional[dist.ProcessGroup]=None,
|
||||
start_epoch: int = 0,
|
||||
start_iter: int = 0,
|
||||
seed: int = 42,
|
||||
drop_last: bool = False,
|
||||
shuffle: bool = True,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
):
|
||||
self.epoch = start_epoch
|
||||
self.iter = start_iter
|
||||
self.seed = seed
|
||||
self.num_samples = len(data_source)
|
||||
|
||||
|
||||
if process_group is not None:
|
||||
# input process group
|
||||
self.rank = dist.get_rank(process_group)
|
||||
self.num_replicas = dist.get_world_size(process_group)
|
||||
|
||||
|
||||
elif dist.is_available() and dist.is_initialized():
|
||||
# use default process group
|
||||
process_group = dist.group.WORLD
|
||||
self.rank = dist.get_rank()
|
||||
self.num_replicas = dist.get_world_size()
|
||||
|
||||
|
||||
else:
|
||||
# single process
|
||||
self.rank = 0
|
||||
self.num_replicas = 1
|
||||
|
||||
|
||||
self.drop_last = drop_last
|
||||
self.shuffle = shuffle
|
||||
|
||||
offset = 0 if drop_last else self.num_replicas - 1
|
||||
|
||||
offset = 0 if drop_last else self.num_replicas - 1
|
||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||
|
||||
|
||||
self._indices = None
|
||||
|
||||
|
||||
def _get_indices(self):
|
||||
if self.shuffle:
|
||||
generator = torch.Generator()
|
||||
|
|
@ -53,26 +53,26 @@ class ResumableDistributedSampler(Sampler[int]):
|
|||
indices = torch.randperm(self.num_samples, generator=generator).tolist()
|
||||
else:
|
||||
indices = torch.arange(self.num_samples).tolist()
|
||||
|
||||
|
||||
if not self.drop_last and self.num_samples < self.total_size:
|
||||
padding_size = self.total_size - len(indices)
|
||||
indices += indices[:padding_size]
|
||||
|
||||
local_indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
|
||||
|
||||
local_indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
|
||||
self.iter = self.iter % self.num_samples_per_replica
|
||||
self._indices = local_indices[self.iter:]
|
||||
|
||||
self._indices = local_indices[self.iter :]
|
||||
|
||||
def __iter__(self):
|
||||
if self._indices is None:
|
||||
self._get_indices()
|
||||
|
||||
|
||||
for i in self._indices:
|
||||
self.iter += 1
|
||||
yield i
|
||||
|
||||
|
||||
self.epoch += 1
|
||||
self._indices = None
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples_per_replica
|
||||
return self.num_samples_per_replica
|
||||
|
|
|
|||
|
|
@ -10,24 +10,26 @@ from torch import Tensor
|
|||
from typing import Any, Dict, List
|
||||
from khaosz.parallel.setup import get_rank
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||
with h5py.File(full_file_path, 'w') as f:
|
||||
with h5py.File(full_file_path, "w") as f:
|
||||
for key, tensors in tensor_group.items():
|
||||
grp = f.create_group(key)
|
||||
for idx, tensor in enumerate(tensors):
|
||||
arr = tensor.cpu().numpy()
|
||||
grp.create_dataset(f'data_{idx}', data=arr)
|
||||
grp.create_dataset(f"data_{idx}", data=arr)
|
||||
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
|
||||
root_path = Path(file_path)
|
||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||
|
||||
|
||||
for h5_file in h5_files:
|
||||
with h5py.File(h5_file, 'r') as f:
|
||||
with h5py.File(h5_file, "r") as f:
|
||||
for key in f.keys():
|
||||
grp = f[key]
|
||||
dsets = []
|
||||
|
|
@ -37,7 +39,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
|||
if share_memory:
|
||||
tensor = tensor.share_memory_()
|
||||
dsets.append(tensor)
|
||||
|
||||
|
||||
if tensor_group.get(key) is None:
|
||||
tensor_group[key] = []
|
||||
tensor_group[key].extend(dsets)
|
||||
|
|
@ -60,7 +62,7 @@ class Checkpoint:
|
|||
self,
|
||||
save_dir: str,
|
||||
) -> None:
|
||||
|
||||
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -72,7 +74,7 @@ class Checkpoint:
|
|||
}
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
|
||||
st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
|
||||
|
||||
@classmethod
|
||||
|
|
@ -83,7 +85,7 @@ class Checkpoint:
|
|||
|
||||
rank = get_rank()
|
||||
save_path = Path(save_dir)
|
||||
|
||||
|
||||
meta = {}
|
||||
if rank == 0:
|
||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||
|
|
@ -100,4 +102,4 @@ class Checkpoint:
|
|||
state_dict=state_dict,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,34 +9,46 @@ class BpeTokenizer:
|
|||
def __init__(self, path=None):
|
||||
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
|
||||
self._special_tokens = ["<|im_start|>", "<|im_end|>"]
|
||||
|
||||
|
||||
model = BPE()
|
||||
self._tokenizer = Tokenizer(model)
|
||||
self._tokenizer.normalizer = normalizers.Sequence([
|
||||
normalizers.NFC(),
|
||||
normalizers.Strip()
|
||||
])
|
||||
|
||||
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
|
||||
pre_tokenizers.UnicodeScripts(),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)
|
||||
])
|
||||
|
||||
self._tokenizer.normalizer = normalizers.Sequence(
|
||||
[normalizers.NFC(), normalizers.Strip()]
|
||||
)
|
||||
|
||||
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||
[
|
||||
pre_tokenizers.UnicodeScripts(),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
|
||||
]
|
||||
)
|
||||
|
||||
self._tokenizer.decoder = decoders.ByteLevel()
|
||||
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
|
||||
|
||||
|
||||
if path is not None:
|
||||
self._tokenizer = Tokenizer.from_file(path)
|
||||
|
||||
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple:
|
||||
|
||||
def _prepare_trainer(
|
||||
self,
|
||||
vocab_size: int,
|
||||
min_freq: int,
|
||||
reserved_token_size: int,
|
||||
max_token_length=18,
|
||||
) -> tuple:
|
||||
assert reserved_token_size > len(self._special_tokens)
|
||||
reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
|
||||
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
|
||||
|
||||
reserved_tokens = [
|
||||
f"<|reserve{i:02d}|>"
|
||||
for i in range(reserved_token_size - len(self._special_tokens))
|
||||
]
|
||||
detail_vocab_size = vocab_size - (
|
||||
len(reserved_tokens) + len(self._special_tokens)
|
||||
)
|
||||
|
||||
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
||||
min_size = len(alphabet) + len(self._control_tokens)
|
||||
assert detail_vocab_size > min_size
|
||||
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=detail_vocab_size,
|
||||
min_frequency=min_freq,
|
||||
|
|
@ -46,61 +58,74 @@ class BpeTokenizer:
|
|||
initial_alphabet=alphabet,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
|
||||
return trainer, detail_vocab_size, reserved_tokens
|
||||
|
||||
def train(self, files, vocab_size, min_freq, reserved_token_size=100):
|
||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||
vocab_size=vocab_size,
|
||||
min_freq=min_freq,
|
||||
reserved_token_size=reserved_token_size
|
||||
reserved_token_size=reserved_token_size,
|
||||
)
|
||||
self._tokenizer.train(files=files, trainer=trainer)
|
||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||
|
||||
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
|
||||
|
||||
def train_from_iterator(
|
||||
self, iterator, vocab_size, min_freq, reserved_token_size=100
|
||||
):
|
||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||
vocab_size=vocab_size,
|
||||
min_freq=min_freq,
|
||||
reserved_token_size=reserved_token_size
|
||||
reserved_token_size=reserved_token_size,
|
||||
)
|
||||
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
|
||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||
|
||||
|
||||
def save(self, path):
|
||||
self._tokenizer.save(path)
|
||||
|
||||
|
||||
def load(self, path):
|
||||
self._tokenizer = Tokenizer.from_file(path)
|
||||
|
||||
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
|
||||
def encode(
|
||||
self,
|
||||
tokens: Union[str, List[str]],
|
||||
out_ids: bool = True,
|
||||
add_special_tokens: bool = False,
|
||||
) -> List:
|
||||
if isinstance(tokens, str):
|
||||
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
|
||||
encoded: Encoding = self._tokenizer.encode(
|
||||
tokens, add_special_tokens=add_special_tokens
|
||||
)
|
||||
return encoded.ids if out_ids else encoded.tokens
|
||||
elif isinstance(tokens, list):
|
||||
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
|
||||
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
|
||||
encoded_list: List[Encoding] = self._tokenizer.encode_batch(
|
||||
tokens, add_special_tokens=add_special_tokens
|
||||
)
|
||||
return [
|
||||
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
|
||||
]
|
||||
|
||||
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
|
||||
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._tokenizer.get_vocab_size()
|
||||
|
||||
|
||||
@property
|
||||
def stop_ids(self) -> List[int]:
|
||||
stop_token = self._control_tokens + self._special_tokens
|
||||
stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
|
||||
return stop_ids
|
||||
|
||||
|
||||
@property
|
||||
def bos_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<bos>")
|
||||
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<eos>")
|
||||
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<pad>")
|
||||
return self._tokenizer.token_to_id("<pad>")
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from khaosz.inference.generator import (
|
|||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
EmbeddingEncoder,
|
||||
GeneratorFactory
|
||||
GeneratorFactory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -19,11 +19,10 @@ __all__ = [
|
|||
"GeneratorCore",
|
||||
"EmbeddingEncoderCore",
|
||||
"KVCacheManager",
|
||||
|
||||
"GenerationRequest",
|
||||
"LoopGenerator",
|
||||
"StreamGenerator",
|
||||
"BatchGenerator",
|
||||
"EmbeddingEncoder",
|
||||
"GeneratorFactory"
|
||||
]
|
||||
"GeneratorFactory",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch import Tensor
|
||||
from torch import Tensor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||
from khaosz.config import ModelParameter, ModelConfig
|
||||
|
|
@ -12,58 +12,61 @@ def apply_sampling_strategies(
|
|||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
filter_value: float = -float("inf")
|
||||
filter_value: float = -float("inf"),
|
||||
) -> Tensor:
|
||||
"""
|
||||
"""
|
||||
Apply sampling strategies to the logits tensor.
|
||||
|
||||
|
||||
Args:
|
||||
logits (Tensor): The logits tensor.
|
||||
temperature (float): The temperature parameter.
|
||||
top_k (int): The top-k parameter.
|
||||
top_p (float): The top-p parameter.
|
||||
filter_value (float, optional): The filter value. Defaults to -float("inf").
|
||||
|
||||
|
||||
Returns:
|
||||
Tensor: The sampled logits tensor.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1))
|
||||
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(
|
||||
dim=1,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
|
||||
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_random_init():
|
||||
init_functions = [
|
||||
'xavier_normal_', 'xavier_uniform_',
|
||||
'kaiming_normal_', 'kaiming_uniform_',
|
||||
'zeros_', 'ones_', 'constant_',
|
||||
'normal_', 'uniform_'
|
||||
"xavier_normal_",
|
||||
"xavier_uniform_",
|
||||
"kaiming_normal_",
|
||||
"kaiming_uniform_",
|
||||
"zeros_",
|
||||
"ones_",
|
||||
"constant_",
|
||||
"normal_",
|
||||
"uniform_",
|
||||
]
|
||||
original_funcs = {}
|
||||
for name in init_functions:
|
||||
|
|
@ -82,7 +85,7 @@ class GeneratorCore:
|
|||
self.model = parameter.model
|
||||
self.tokenizer = parameter.tokenizer
|
||||
self.config = parameter.config
|
||||
|
||||
|
||||
def generate_iterator(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
|
|
@ -91,18 +94,18 @@ class GeneratorCore:
|
|||
top_p: float,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||
start_pos: int = 0
|
||||
)-> Tuple[Tensor, int]:
|
||||
|
||||
start_pos: int = 0,
|
||||
) -> Tuple[Tensor, int]:
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
cache_increase = input_ids.size(-1)
|
||||
|
||||
cache_increase = input_ids.size(-1)
|
||||
|
||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
|
||||
return next_token_id, cache_increase
|
||||
|
||||
def generate_loop(
|
||||
|
|
@ -115,14 +118,21 @@ class GeneratorCore:
|
|||
attn_mask: Optional[Tensor] = None,
|
||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||
start_pos: int = 0,
|
||||
callback: Optional[Callable[..., Any]] = None
|
||||
callback: Optional[Callable[..., Any]] = None,
|
||||
) -> List[int]:
|
||||
cur_cache_pos = start_pos
|
||||
|
||||
|
||||
for _ in range(len(ids), self.config.max_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
||||
|
||||
input_ids,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
attn_mask,
|
||||
kv_caches,
|
||||
cur_cache_pos,
|
||||
)
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
|
@ -132,9 +142,9 @@ class GeneratorCore:
|
|||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
break
|
||||
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def to(self, *args, **kargs) -> Self:
|
||||
self.model.to(*args, **kargs)
|
||||
return self
|
||||
|
|
@ -145,32 +155,35 @@ class EmbeddingEncoderCore:
|
|||
self.model = parameter.model
|
||||
self.tokenizer = parameter.tokenizer
|
||||
self.config = parameter.config
|
||||
|
||||
|
||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||
with_batch = isinstance(sentence, list)
|
||||
ids = self.tokenizer.encode(sentence)
|
||||
batch_ids = ids if with_batch else [ids]
|
||||
max_model_len = self.config.max_len
|
||||
|
||||
|
||||
all_fragments = []
|
||||
fragment_origin_idx = []
|
||||
|
||||
|
||||
for i, seq in enumerate(batch_ids):
|
||||
if len(seq) > max_model_len:
|
||||
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
|
||||
fragments = [
|
||||
seq[j : j + max_model_len]
|
||||
for j in range(0, len(seq), max_model_len)
|
||||
]
|
||||
all_fragments.extend(fragments)
|
||||
fragment_origin_idx.extend([i] * len(fragments))
|
||||
else:
|
||||
all_fragments.append(seq)
|
||||
fragment_origin_idx.append(i)
|
||||
|
||||
#if empty fragments
|
||||
|
||||
# if empty fragments
|
||||
if not all_fragments or not ids:
|
||||
return [] if with_batch else torch.tensor([])
|
||||
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
|
||||
|
||||
|
||||
padded_ids = []
|
||||
masks = []
|
||||
for seq in all_fragments:
|
||||
|
|
@ -179,24 +192,30 @@ class EmbeddingEncoderCore:
|
|||
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
|
||||
padded_ids.append(padded_seq)
|
||||
masks.append(mask)
|
||||
|
||||
|
||||
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
|
||||
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
|
||||
# [num_fragments, seq_len, hidden_size]
|
||||
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
||||
|
||||
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
||||
|
||||
sentence_embs: List[Tensor] = []
|
||||
for i in range(len(batch_ids)):
|
||||
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
||||
indices = [
|
||||
idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
|
||||
]
|
||||
if indices:
|
||||
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
||||
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
||||
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
||||
sum_frags = torch.sum(
|
||||
fragment_embs[indices, :, :], dim=1
|
||||
) # [frags, hidden_size]
|
||||
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(
|
||||
1
|
||||
) # [frags, 1]
|
||||
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
||||
sentence_embs.append(emb.flatten())
|
||||
|
||||
|
||||
if with_batch:
|
||||
return [emb.flatten() for emb in sentence_embs]
|
||||
else:
|
||||
|
|
@ -209,11 +228,11 @@ class EmbeddingEncoderCore:
|
|||
|
||||
class KVCacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelConfig,
|
||||
self,
|
||||
config: ModelConfig,
|
||||
batch_size: int,
|
||||
device: torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
device: torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
|
|
@ -221,25 +240,41 @@ class KVCacheManager:
|
|||
self.num_layers = config.n_layers
|
||||
self.max_len = config.max_len
|
||||
self.num_heads = config.n_kv_heads
|
||||
self.head_dim = config.dim //config.n_heads
|
||||
|
||||
self.head_dim = config.dim // config.n_heads
|
||||
|
||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||
self._seq_mask: Tensor = None
|
||||
self._initialize()
|
||||
|
||||
|
||||
def _initialize(self):
|
||||
k_cache = torch.empty(
|
||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
(
|
||||
self.batch_size,
|
||||
self.max_len,
|
||||
self.num_layers,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
v_cache = torch.empty(
|
||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
(
|
||||
self.batch_size,
|
||||
self.max_len,
|
||||
self.num_layers,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self._kv_cache = (k_cache, v_cache)
|
||||
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)
|
||||
self._seq_mask = torch.ones(
|
||||
(self.batch_size, self.max_len), device=self.device, dtype=torch.bool
|
||||
)
|
||||
|
||||
def update(self, active_mask: Tensor):
|
||||
def update(self, active_mask: Tensor):
|
||||
k_cache, v_cache = self._kv_cache
|
||||
self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
|
||||
self._seq_mask = self._seq_mask[active_mask]
|
||||
|
|
@ -250,14 +285,14 @@ class KVCacheManager:
|
|||
self._seq_mask = None
|
||||
else:
|
||||
self._initialize()
|
||||
|
||||
|
||||
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
bool_mask = (input_ids != pad_id)
|
||||
self._seq_mask[: batch_size, : seq_len] = bool_mask
|
||||
bool_mask = input_ids != pad_id
|
||||
self._seq_mask[:batch_size, :seq_len] = bool_mask
|
||||
|
||||
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
|
||||
return self._kv_cache
|
||||
|
||||
def get_seq_mask(self) -> Tensor:
|
||||
return self._seq_mask
|
||||
|
||||
def get_seq_mask(self) -> Tensor:
|
||||
return self._seq_mask
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
from dataclasses import dataclass
|
||||
from torch import Tensor
|
||||
from torch import Tensor
|
||||
from typing import List, Tuple, Union, Optional, Generator
|
||||
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
|
|
@ -8,10 +8,11 @@ from khaosz.config.param_config import ModelParameter
|
|||
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
|
||||
|
||||
def build_prompt(
|
||||
query: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
history: Optional[HistoryType] = None
|
||||
history: Optional[HistoryType] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build prompt in ChatML format for query and history.
|
||||
|
|
@ -42,17 +43,17 @@ def build_prompt(
|
|||
|
||||
|
||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||
"""
|
||||
"""
|
||||
Pad a list of sequences to a fixed length.
|
||||
|
||||
|
||||
Args:
|
||||
ids_list (List[List[int]]): A list of sequences.
|
||||
max_ids_len (int): The maximum length of sequences.
|
||||
pad_id (int): The id to pad sequences.
|
||||
|
||||
|
||||
Returns:
|
||||
List[List[int]]: A list of padded sequences.
|
||||
|
||||
|
||||
"""
|
||||
max_ids_len = max(len(ids) for ids in ids_list)
|
||||
new_ids_list = []
|
||||
|
|
@ -60,7 +61,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
|
|||
pad_len = max_ids_len - len(ids)
|
||||
padded_seq = [pad_id] * pad_len + ids
|
||||
new_ids_list.append(padded_seq)
|
||||
|
||||
|
||||
return new_ids_list, max_ids_len
|
||||
|
||||
|
||||
|
|
@ -68,7 +69,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
|
|||
class GenerationRequest:
|
||||
"""
|
||||
Request parameters for text generation.
|
||||
|
||||
|
||||
Attributes:
|
||||
top_k: Top-k sampling parameter.
|
||||
top_p: Top-p (nucleus) sampling parameter.
|
||||
|
|
@ -79,6 +80,7 @@ class GenerationRequest:
|
|||
system_prompt: System prompt for the conversation.
|
||||
stream: Whether to use streaming generation.
|
||||
"""
|
||||
|
||||
top_k: int
|
||||
top_p: float
|
||||
temperature: float
|
||||
|
|
@ -101,63 +103,66 @@ class GenerationRequest:
|
|||
class LoopGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
|
||||
def generate(self, request: GenerationRequest) -> str:
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
|
||||
|
||||
prompt = build_prompt(request.query, request.history)
|
||||
ids = self.tokenizer.encode(prompt)
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
self.model.eval()
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
|
||||
|
||||
ids = self.generate_loop(
|
||||
input_ids,
|
||||
ids,
|
||||
request.temperature,
|
||||
request.top_k,
|
||||
request.top_p,
|
||||
kv_caches=kv_caches,
|
||||
input_ids,
|
||||
ids,
|
||||
request.temperature,
|
||||
request.top_k,
|
||||
request.top_p,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class StreamGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
||||
|
||||
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
|
||||
prompt = build_prompt(request.query, request.history)
|
||||
ids = self.tokenizer.encode(prompt)
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
|
||||
|
||||
for _ in range(len(ids), self.config.max_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, request.temperature, request.top_k, request.top_p,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
input_ids,
|
||||
request.temperature,
|
||||
request.top_k,
|
||||
request.top_p,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos,
|
||||
)
|
||||
|
||||
|
||||
input_ids = next_token_id
|
||||
ids.append(next_token_id.item())
|
||||
cur_cache_pos += cache_increase
|
||||
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
yield response
|
||||
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
yield response + "\n"
|
||||
break
|
||||
|
|
@ -166,131 +171,140 @@ class StreamGenerator(GeneratorCore):
|
|||
class BatchGenerator(GeneratorCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
|
||||
def generate(self, request: GenerationRequest) -> List[str]:
|
||||
batch_size = len(request.query)
|
||||
if request.history is None:
|
||||
request.history = [[] for _ in range(batch_size)]
|
||||
|
||||
prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
|
||||
|
||||
|
||||
prompts = [
|
||||
build_prompt(query, history)
|
||||
for query, history in zip(request.query, request.history)
|
||||
]
|
||||
|
||||
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
||||
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
||||
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, batch_size, device=device)
|
||||
|
||||
|
||||
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
||||
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
||||
activate_task_mask = [True] * batch_size
|
||||
|
||||
|
||||
start_cache_pos = max_ids_len
|
||||
cur_cache_pos = 0
|
||||
|
||||
|
||||
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
attn_mask =cache_manager.get_seq_mask()
|
||||
|
||||
attn_mask = cache_manager.get_seq_mask()
|
||||
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_tensor, request.temperature, request.top_k, request.top_p,
|
||||
attn_mask=attn_mask,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
input_tensor,
|
||||
request.temperature,
|
||||
request.top_k,
|
||||
request.top_p,
|
||||
attn_mask=attn_mask,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos,
|
||||
)
|
||||
|
||||
|
||||
cur_cache_pos += cache_increase
|
||||
active_mask = []
|
||||
c_ids = 0
|
||||
|
||||
|
||||
for i in range(batch_size):
|
||||
if activate_task_mask[i]:
|
||||
token = next_token_id[c_ids, :].item()
|
||||
ids_list[i].append(token)
|
||||
c_ids += 1
|
||||
|
||||
|
||||
is_active = not token in self.tokenizer.stop_ids
|
||||
activate_task_mask[i] = is_active
|
||||
activate_task_mask[i] = is_active
|
||||
active_mask.append(is_active)
|
||||
|
||||
|
||||
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
|
||||
cache_manager.update(active_mask)
|
||||
input_tensor = next_token_id[active_mask, :]
|
||||
|
||||
max_ids_len += 1
|
||||
|
||||
|
||||
responses = [str()] * batch_size
|
||||
for i in range(batch_size):
|
||||
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
|
||||
request.history[i].append((request.query[i], responses[i]))
|
||||
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
class EmbeddingEncoder(EmbeddingEncoderCore):
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
|
||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||
return super().encode(sentence)
|
||||
|
||||
|
||||
class GeneratorFactory:
|
||||
"""Factory class for creating generator instances.
|
||||
|
||||
|
||||
Provides smart generator selection based on request characteristics:
|
||||
- Streaming: Use StreamGenerator for streaming output
|
||||
- Batch: Use BatchGenerator when query is a list
|
||||
- Single: Use LoopGenerator for single query non-streaming
|
||||
|
||||
|
||||
Example usage:
|
||||
generator = GeneratorFactory.create_generator(parameter, request)
|
||||
result = generator.generate(request)
|
||||
"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||
def create_generator(
|
||||
parameter: ModelParameter, request: GenerationRequest
|
||||
) -> GeneratorCore:
|
||||
"""Create a generator based on request characteristics.
|
||||
|
||||
|
||||
Args:
|
||||
parameter: Model parameters containing model, tokenizer, config
|
||||
request: Generation request with query, options, etc.
|
||||
|
||||
|
||||
Returns:
|
||||
Appropriate GeneratorCore subclass instance
|
||||
"""
|
||||
# Streaming generation: check stream field first
|
||||
if request.stream:
|
||||
return StreamGenerator(parameter)
|
||||
|
||||
|
||||
# Batch generation: query is a list of strings
|
||||
if isinstance(request.query, list):
|
||||
return BatchGenerator(parameter)
|
||||
|
||||
|
||||
# Default: single query non-streaming
|
||||
return LoopGenerator(parameter)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
|
||||
"""Create an embedding encoder instance.
|
||||
|
||||
|
||||
Args:
|
||||
parameter: Model parameters
|
||||
|
||||
|
||||
Returns:
|
||||
EmbeddingEncoderCore instance
|
||||
"""
|
||||
return EmbeddingEncoder(parameter)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||
def create(
|
||||
cls, parameter: ModelParameter, request: GenerationRequest
|
||||
) -> GeneratorCore:
|
||||
"""Convenience method that delegates to create_generator.
|
||||
|
||||
|
||||
Args:
|
||||
parameter: Model parameters
|
||||
request: Generation request
|
||||
|
||||
|
||||
Returns:
|
||||
Generator instance
|
||||
"""
|
||||
return cls.create_generator(parameter, request)
|
||||
|
||||
|
|
@ -1,17 +1,10 @@
|
|||
from khaosz.model.module import (
|
||||
from khaosz.model.module import (
|
||||
Linear,
|
||||
RMSNorm,
|
||||
RMSNorm,
|
||||
MLP,
|
||||
GQA,
|
||||
DecoderBlock,
|
||||
)
|
||||
from khaosz.model.transformer import Transformer
|
||||
|
||||
__all__ = [
|
||||
"Linear",
|
||||
"RMSNorm",
|
||||
"MLP",
|
||||
"GQA",
|
||||
"DecoderBlock",
|
||||
"Transformer"
|
||||
]
|
||||
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from typing import Optional, Tuple
|
|||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
"""
|
||||
"""
|
||||
Repeat k times along the dimension for attention heads.
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
|
|
@ -15,7 +15,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
|||
Returns:
|
||||
Tensor: The repeated tensor.
|
||||
"""
|
||||
|
||||
|
||||
bs, slen, n_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
|
|
@ -25,12 +25,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
|||
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
def get_rotary_emb(
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get the rotary embedding for the given dimension and maximum length.
|
||||
Args:
|
||||
dim (int): The dimension of the input.
|
||||
|
|
@ -46,6 +47,7 @@ def get_rotary_emb(
|
|||
|
||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||
"""
|
||||
Apply rotary embedding to the input tensor using cos/sin form.
|
||||
|
|
@ -55,49 +57,49 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
|
|||
Returns:
|
||||
Tensor: The output tensor (rotated, same shape as input).
|
||||
"""
|
||||
|
||||
|
||||
dtype = x.dtype
|
||||
cos, sin = rotary_emb
|
||||
|
||||
|
||||
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||
|
||||
|
||||
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
|
||||
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
|
||||
|
||||
|
||||
x_real_rot = x_real * cos - x_imag * sin
|
||||
x_imag_rot = x_real * sin + x_imag * cos
|
||||
|
||||
|
||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
||||
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
||||
|
||||
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
||||
|
||||
return x_out.to(dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int, base: int=10000):
|
||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self.max_len_cached = None
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
self.max_len_cached = max_len
|
||||
|
||||
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
||||
seq_len = x.size(1)
|
||||
|
||||
|
||||
if self.max_len_cached < seq_len + start_pos:
|
||||
self._set_rotary_buffer(seq_len + start_pos)
|
||||
|
||||
|
||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||
|
||||
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
|
|
@ -115,43 +117,42 @@ class RMSNorm(nn.Module):
|
|||
def __init__(self, dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.normalized_shape = (dim, )
|
||||
self.normalized_shape = (dim,)
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
||||
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
||||
return rms.to(x.dtype)
|
||||
|
||||
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim: int, dim_feed_forward: int):
|
||||
super().__init__()
|
||||
self.up = Linear(dim, dim_feed_forward)
|
||||
self.gate = Linear(dim, dim_feed_forward)
|
||||
self.down = Linear(dim_feed_forward, dim)
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
out = self.down(gated)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class GQA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
use_qk_norm: bool,
|
||||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int
|
||||
layer_id: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % n_heads == 0
|
||||
assert n_heads % n_kv_heads == 0
|
||||
|
||||
|
||||
self.head_dim = dim // n_heads
|
||||
self.layer_id = layer_id
|
||||
self.dim = dim
|
||||
|
|
@ -165,11 +166,11 @@ class GQA(nn.Module):
|
|||
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.o_proj = Linear(dim, dim)
|
||||
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
|
||||
|
||||
if self.use_gated_attention:
|
||||
self.gate = Linear(dim, dim)
|
||||
|
||||
|
|
@ -177,14 +178,14 @@ class GQA(nn.Module):
|
|||
batch_size, seq_len, _ = x.shape
|
||||
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||
return x
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
start_pos: int = 0
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = mask is None
|
||||
|
|
@ -194,31 +195,36 @@ class GQA(nn.Module):
|
|||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
||||
|
||||
|
||||
if self.use_qk_norm:
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
|
||||
# copy to cache
|
||||
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
|
||||
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
|
||||
|
||||
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||
|
||||
# get cache
|
||||
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||
|
||||
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
|
||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
||||
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||
|
||||
sdqa_out = (
|
||||
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
if self.use_gated_attention:
|
||||
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||
|
||||
|
||||
out = self.o_proj(sdqa_out)
|
||||
|
||||
return out
|
||||
|
|
@ -227,15 +233,15 @@ class GQA(nn.Module):
|
|||
class MLA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int
|
||||
layer_id: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -252,45 +258,46 @@ class MLA(nn.Module):
|
|||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
|
||||
|
||||
|
||||
# KV (k_nope, k_rope, v)
|
||||
self.kv_b_proj = Linear(
|
||||
kv_lora_rank,
|
||||
kv_lora_rank,
|
||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||
)
|
||||
|
||||
|
||||
self.o_proj = Linear(dim, dim, bias=False)
|
||||
|
||||
|
||||
if use_gated_attention:
|
||||
self.gate = Linear(dim, dim, bias=False)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
start_pos: int = 0
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = mask is None
|
||||
|
||||
|
||||
q = self.q_proj(x)
|
||||
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||
|
||||
|
||||
kv_compressed = self.kv_a_proj(x)
|
||||
kv_compressed = self.kv_norm(kv_compressed)
|
||||
|
||||
|
||||
kv = self.kv_b_proj(kv_compressed)
|
||||
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
|
||||
|
||||
|
||||
k_nope, k_rope, v = torch.split(
|
||||
kv,
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim],
|
||||
dim=-1
|
||||
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
|
||||
)
|
||||
|
||||
q_nope, q_rope = (
|
||||
q[..., : self.qk_nope_head_dim],
|
||||
q[..., self.qk_rope_head_dim :],
|
||||
)
|
||||
|
||||
q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_rope_head_dim:]
|
||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||
|
||||
|
|
@ -299,41 +306,48 @@ class MLA(nn.Module):
|
|||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
|
||||
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
|
||||
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||
|
||||
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||
|
||||
|
||||
if self.use_gated_attention:
|
||||
attn_out = attn_out * F.sigmoid(self.gate(x))
|
||||
|
||||
out = self.o_proj(attn_out)
|
||||
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: int,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: int,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = GQA(dim, n_heads, n_kv_heads,
|
||||
use_qk_norm, norm_eps, use_gated_attention, layer_id)
|
||||
self.attention = GQA(
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
use_qk_norm,
|
||||
norm_eps,
|
||||
use_gated_attention,
|
||||
layer_id,
|
||||
)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = MLP(dim, dim_ffn)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
|
|
@ -341,24 +355,20 @@ class DecoderBlock(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
start_pos: int = 0
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
# attention
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
kv_cache,
|
||||
start_pos
|
||||
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
|
||||
)
|
||||
x = attn_output + x
|
||||
|
||||
|
||||
# feed forward
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
|
@ -366,6 +376,6 @@ class Embedding(nn.Module):
|
|||
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
return F.embedding(x, self.weight)
|
||||
|
|
|
|||
|
|
@ -4,15 +4,21 @@ import torch.nn as nn
|
|||
from torch import Tensor
|
||||
from typing import Any, Mapping, Optional, Tuple
|
||||
from khaosz.config.model_config import ModelConfig
|
||||
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
|
||||
from khaosz.model.module import (
|
||||
Embedding,
|
||||
DecoderBlock,
|
||||
Linear,
|
||||
RMSNorm,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
def process_attention_mask(
|
||||
seq_mask: Tensor,
|
||||
input_tensor: Tensor,
|
||||
start_pos: int = 0,
|
||||
is_causal: bool = False,
|
||||
) -> Tensor:
|
||||
seq_mask: Tensor,
|
||||
input_tensor: Tensor,
|
||||
start_pos: int = 0,
|
||||
is_causal: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Create attention mask for GQA
|
||||
Args:
|
||||
|
|
@ -26,32 +32,36 @@ def process_attention_mask(
|
|||
device = input_tensor.device
|
||||
dtype = input_tensor.dtype
|
||||
seq_len = input_tensor.size(1)
|
||||
|
||||
|
||||
if seq_mask is None:
|
||||
if start_pos != 0:
|
||||
# for single prompt chat
|
||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
if seq_mask.dim() > 2:
|
||||
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||
# if ndim > 2, it's 4D tensor
|
||||
return seq_mask
|
||||
|
||||
|
||||
batch_size = seq_mask.size(0)
|
||||
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||
# (bsz, start_pos + seq_len)
|
||||
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
|
||||
expanded_mask = seq_mask.unsqueeze(1).expand(
|
||||
batch_size, seq_len, start_pos + seq_len
|
||||
)
|
||||
# (bsz, seq_len, start_pos + seq_len)
|
||||
|
||||
|
||||
if is_causal:
|
||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||
|
||||
|
||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
||||
attention_mask = attention_mask.masked_fill_(
|
||||
~expanded_mask, -torch.finfo(dtype).max / 2
|
||||
).unsqueeze(1)
|
||||
# (bsz, 1, seq_len, seq_len + start_pos)
|
||||
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
|
|
@ -59,26 +69,38 @@ class Transformer(nn.Module):
|
|||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
|
||||
self.rotary_embeding = RotaryEmbedding(
|
||||
config.dim // config.n_heads, config.max_len
|
||||
)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
|
||||
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
|
||||
for layer_id in range(config.n_layers)
|
||||
])
|
||||
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderBlock(
|
||||
config.dim,
|
||||
config.n_heads,
|
||||
config.dim_ffn,
|
||||
config.n_kv_heads,
|
||||
config.norm_eps,
|
||||
config.use_qk_norm,
|
||||
config.use_gated_attention,
|
||||
layer_id,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||
|
||||
|
||||
if self.config.tie_weight == True:
|
||||
self.lm_head.weight = self.embed_tokens.weight
|
||||
|
||||
self._init_parameters()
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||
lm_head_key = 'lm_head.weight'
|
||||
embed_key = 'embed_tokens.weight'
|
||||
lm_head_key = "lm_head.weight"
|
||||
embed_key = "embed_tokens.weight"
|
||||
|
||||
if self.config.tie_weight == True:
|
||||
# same tensor
|
||||
|
|
@ -87,48 +109,44 @@ class Transformer(nn.Module):
|
|||
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||
# use clone to avoid sharing the same tensor
|
||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||
|
||||
|
||||
return super().load_state_dict(state_dict, strict, assign)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
state_dict = super().state_dict(
|
||||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||
)
|
||||
|
||||
if self.config.tie_weight == True:
|
||||
lm_head_key = prefix + 'lm_head.weight'
|
||||
lm_head_key = prefix + "lm_head.weight"
|
||||
if lm_head_key in state_dict:
|
||||
del state_dict[lm_head_key]
|
||||
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _init_parameters(self):
|
||||
for param in self.parameters():
|
||||
if param.dim() > 1:
|
||||
nn.init.normal_(param, mean=0.0, std=0.006)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor]=None,
|
||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None,
|
||||
start_pos: int = 0
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
rotary_emb = self.rotary_embeding(x, start_pos)
|
||||
|
||||
attn_mask = process_attention_mask(
|
||||
input_mask, x, start_pos, is_causal=True
|
||||
)
|
||||
|
||||
|
||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
return {
|
||||
"logits": logits,
|
||||
"hidden_states": hidden_states
|
||||
}
|
||||
|
||||
|
||||
return {"logits": logits, "hidden_states": hidden_states}
|
||||
|
|
|
|||
|
|
@ -1,27 +1,21 @@
|
|||
from khaosz.parallel.setup import (
|
||||
get_world_size,
|
||||
get_world_size,
|
||||
get_rank,
|
||||
get_current_device,
|
||||
|
||||
only_on_rank,
|
||||
setup_parallel,
|
||||
spawn_parallel_fn
|
||||
setup_parallel,
|
||||
spawn_parallel_fn,
|
||||
)
|
||||
|
||||
from khaosz.parallel.module import (
|
||||
RowParallelLinear,
|
||||
ColumnParallelLinear
|
||||
)
|
||||
from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
|
||||
|
||||
__all__ = [
|
||||
"get_world_size",
|
||||
"get_rank",
|
||||
"get_current_device",
|
||||
|
||||
"only_on_rank",
|
||||
"setup_parallel",
|
||||
"spawn_parallel_fn",
|
||||
|
||||
"RowParallelLinear",
|
||||
"ColumnParallelLinear"
|
||||
"ColumnParallelLinear",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -17,91 +17,99 @@ class ParallelModel(nn.Module):
|
|||
|
||||
class RowParallelLinear(ParallelModel):
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
process_group: dist.ProcessGroup,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
reduce_results: bool = True
|
||||
reduce_results: bool = True,
|
||||
):
|
||||
super().__init__(process_group)
|
||||
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.in_features_per_rank = in_features // self.world_size
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
|
||||
if in_features % self.world_size != 0:
|
||||
raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}")
|
||||
|
||||
raise ValueError(
|
||||
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
||||
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
||||
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
output = F.linear(input, self.weight)
|
||||
|
||||
|
||||
if self.reduce_results:
|
||||
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
|
||||
|
||||
|
||||
if self.bias is not None:
|
||||
output += self.bias
|
||||
|
||||
|
||||
return output
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||
full_weight = state_dict.get('weight')
|
||||
full_bias = state_dict.get('bias')
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||
full_weight = state_dict.get("weight")
|
||||
full_bias = state_dict.get("bias")
|
||||
|
||||
start_idx = self.rank * self.in_features_per_rank
|
||||
end_idx = start_idx + self.in_features_per_rank
|
||||
weight_slice = full_weight[:, start_idx:end_idx]
|
||||
self.weight.data.copy_(weight_slice)
|
||||
|
||||
|
||||
if self.bias is not None:
|
||||
self.bias.data.copy_(full_bias)
|
||||
|
||||
|
||||
class ColumnParallelLinear(ParallelModel):
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
process_group: dist.ProcessGroup,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
gather_results: bool = True
|
||||
gather_results: bool = True,
|
||||
):
|
||||
super().__init__(process_group)
|
||||
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.out_features_per_rank = out_features // self.world_size
|
||||
self.gather_results = gather_results
|
||||
|
||||
if out_features % self.world_size != 0:
|
||||
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
|
||||
raise ValueError(
|
||||
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.out_features_per_rank, self.in_features)
|
||||
)
|
||||
self.bias = (
|
||||
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
|
||||
self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
output = F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
if self.gather_results:
|
||||
output_list = [torch.empty_like(output) for _ in range(self.world_size)]
|
||||
dist.all_gather(output_list, output, group=self.process_group)
|
||||
output = torch.cat(output_list, dim=-1)
|
||||
|
||||
|
||||
return output
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||
full_weight = state_dict.get('weight')
|
||||
full_bias = state_dict.get('bias')
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||
full_weight = state_dict.get("weight")
|
||||
full_bias = state_dict.get("bias")
|
||||
|
||||
start_idx = self.rank * self.out_features_per_rank
|
||||
end_idx = start_idx + self.out_features_per_rank
|
||||
weight_slice = full_weight[start_idx:end_idx, :]
|
||||
self.weight.data.copy_(weight_slice)
|
||||
|
||||
|
||||
if self.bias is not None:
|
||||
bias_slice = full_bias[start_idx:end_idx]
|
||||
self.bias.data.copy_(bias_slice)
|
||||
self.bias.data.copy_(bias_slice)
|
||||
|
|
|
|||
|
|
@ -11,73 +11,74 @@ from typing import Callable, List, Optional
|
|||
def get_current_device():
|
||||
return os.environ["LOCAL_DEVICE"]
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def setup_parallel(
|
||||
rank: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
device_ids: Optional[List[int]] = None
|
||||
device_ids: Optional[List[int]] = None,
|
||||
):
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
yield dist.group.WORLD
|
||||
return
|
||||
return
|
||||
|
||||
if world_size <= 1:
|
||||
yield None
|
||||
return
|
||||
|
||||
|
||||
if device_ids is None:
|
||||
device_ids = [i for i in range(world_size)]
|
||||
|
||||
|
||||
rank = device_ids[rank % len(device_ids)]
|
||||
device_id = torch.device(device_type, device_ids[rank])
|
||||
|
||||
os.environ['MASTER_ADDR'] = master_addr
|
||||
os.environ['MASTER_PORT'] = master_port
|
||||
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = master_port
|
||||
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
|
||||
dist.init_process_group(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
device_id=device_id
|
||||
rank=rank, world_size=world_size, backend=backend, device_id=device_id
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if backend == "nccl" and torch.cuda.is_available():
|
||||
torch.cuda.set_device(device_id)
|
||||
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||
elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
torch.xpu.set_device(device_id)
|
||||
|
||||
|
||||
yield dist.group.WORLD
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def only_on_rank(rank, sync=False):
|
||||
"""
|
||||
decorator to run a function only on a specific rank.
|
||||
"""
|
||||
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
|
@ -89,67 +90,81 @@ def only_on_rank(rank, sync=False):
|
|||
dist.barrier()
|
||||
|
||||
return ret_args
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def wrapper_spawn_func(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
master_addr: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
master_addr: str,
|
||||
master_port: str,
|
||||
device_type: str,
|
||||
device_ids: List[int],
|
||||
func: Callable,
|
||||
kwargs: dict
|
||||
device_ids: List[int],
|
||||
func: Callable,
|
||||
kwargs: dict,
|
||||
):
|
||||
try:
|
||||
with setup_parallel(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
device_type=device_type,
|
||||
device_ids=device_ids
|
||||
device_ids=device_ids,
|
||||
):
|
||||
func(**kwargs)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in rank {rank}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def spawn_parallel_fn(
|
||||
func: Callable,
|
||||
world_size: int,
|
||||
func: Callable,
|
||||
world_size: int,
|
||||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
device_ids: Optional[List[int]] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
# clear environment variables
|
||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
|
||||
for key in [
|
||||
"MASTER_ADDR",
|
||||
"MASTER_PORT",
|
||||
"RANK",
|
||||
"WORLD_SIZE",
|
||||
"LOCAL_RANK",
|
||||
"LOCAL_DEVICE",
|
||||
]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
|
||||
if world_size == 1:
|
||||
device_ids = device_ids or [0]
|
||||
device_id = torch.device(device_type, device_ids[0])
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
|
||||
device_type, device_ids, func, kwargs)
|
||||
wrapper_spawn_func_args = (
|
||||
world_size,
|
||||
backend,
|
||||
master_addr,
|
||||
master_port,
|
||||
device_type,
|
||||
device_ids,
|
||||
func,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
mp.spawn(
|
||||
wrapper_spawn_func,
|
||||
nprocs=world_size,
|
||||
args=wrapper_spawn_func_args,
|
||||
join=True
|
||||
)
|
||||
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,15 +14,12 @@ from khaosz.trainer.train_callback import (
|
|||
__all__ = [
|
||||
# Main trainer
|
||||
"Trainer",
|
||||
|
||||
# Strategy factory
|
||||
"StrategyFactory",
|
||||
"BaseStrategy",
|
||||
|
||||
# Scheduler factory
|
||||
"SchedulerFactory",
|
||||
"BaseScheduler",
|
||||
|
||||
# Callbacks
|
||||
"TrainCallback",
|
||||
"GradientClippingCallback",
|
||||
|
|
@ -30,4 +27,4 @@ __all__ = [
|
|||
"CheckpointCallback",
|
||||
"ProgressBarCallback",
|
||||
"MetricLoggerCallback",
|
||||
]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import torch.nn as nn
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||
""" Compute gradient norm for each parameter in the model. """
|
||||
"""Compute gradient norm for each parameter in the model."""
|
||||
norms = {}
|
||||
for name, param in model.named_parameters():
|
||||
norms[name] = 0.0
|
||||
|
|
@ -11,8 +12,9 @@ def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
|||
norms[name] = norm
|
||||
return norms
|
||||
|
||||
|
||||
def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||
""" Compute standard deviation of gradients for each parameter. """
|
||||
"""Compute standard deviation of gradients for each parameter."""
|
||||
stds = {}
|
||||
for name, param in model.named_parameters():
|
||||
stds[name] = 0.0
|
||||
|
|
@ -21,41 +23,45 @@ def grad_std(model: nn.Module) -> Dict[str, float]:
|
|||
stds[name] = std
|
||||
return stds
|
||||
|
||||
|
||||
def grad_max(model: nn.Module) -> Dict[str, float]:
|
||||
""" Find the maximum absolute gradient value for each parameter. """
|
||||
"""Find the maximum absolute gradient value for each parameter."""
|
||||
max_vals = {}
|
||||
for name, param in model.named_parameters():
|
||||
max_vals[name] = -float('inf')
|
||||
max_vals[name] = -float("inf")
|
||||
if param.grad:
|
||||
max_val = param.grad.data.max().item()
|
||||
max_vals[name] = max_val
|
||||
|
||||
|
||||
return max_vals
|
||||
|
||||
|
||||
def grad_min(model: nn.Module) -> Dict[str, float]:
|
||||
""" Find the minimum absolute gradient value for each parameter. """
|
||||
"""Find the minimum absolute gradient value for each parameter."""
|
||||
min_vals = {}
|
||||
for name, param in model.named_parameters():
|
||||
min_vals[name] = float('inf')
|
||||
min_vals[name] = float("inf")
|
||||
if param.grad:
|
||||
min_val = param.grad.data.min().item()
|
||||
min_vals[name] = min_val
|
||||
|
||||
|
||||
return min_vals
|
||||
|
||||
|
||||
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||
""" Compute mean of gradients for each parameter. """
|
||||
"""Compute mean of gradients for each parameter."""
|
||||
means = {}
|
||||
for name, param in model.named_parameters():
|
||||
means[name] = 0.0
|
||||
if param.grad:
|
||||
mean = param.grad.data.mean().item()
|
||||
means[name] = mean
|
||||
|
||||
|
||||
return means
|
||||
|
||||
|
||||
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||
""" Count the number of NaNs in gradients for each parameter. """
|
||||
"""Count the number of NaNs in gradients for each parameter."""
|
||||
nan_nums = {}
|
||||
for name, param in model.named_parameters():
|
||||
nan_nums[name] = 0
|
||||
|
|
@ -64,26 +70,34 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
|||
nan_nums[name] = nan_num
|
||||
return nan_nums
|
||||
|
||||
|
||||
def ctx_get_loss(ctx):
|
||||
return ctx.loss
|
||||
|
||||
|
||||
def ctx_get_lr(ctx):
|
||||
return ctx.optimizer.param_groups[-1]['lr']
|
||||
return ctx.optimizer.param_groups[-1]["lr"]
|
||||
|
||||
|
||||
def ctx_get_grad_norm(ctx):
|
||||
return grad_norm(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_std(ctx):
|
||||
return grad_std(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_max(ctx):
|
||||
return grad_max(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_min(ctx):
|
||||
return grad_min(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_mean(ctx):
|
||||
return grad_mean(ctx.model)
|
||||
|
||||
|
||||
def ctx_get_grad_nan_num(ctx):
|
||||
return grad_nan_num(ctx.model)
|
||||
return grad_nan_num(ctx.model)
|
||||
|
|
|
|||
|
|
@ -9,71 +9,75 @@ from khaosz.config.schedule_config import ScheduleConfig
|
|||
|
||||
class BaseScheduler(LRScheduler, ABC):
|
||||
"""Base scheduler class for all other schedulers."""
|
||||
|
||||
|
||||
def __init__(self, optimizer, last_epoch: int = -1):
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_lr(self) -> List[float]:
|
||||
"""Calculate the current learning rate."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
return super().state_dict()
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]):
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
class SchedulerFactory:
|
||||
"""Factory class for creating learning rate schedulers.
|
||||
|
||||
|
||||
Supports decorator-based registration for extensible scheduler types.
|
||||
Also supports creation from ScheduleConfig objects.
|
||||
|
||||
|
||||
Example usage:
|
||||
@SchedulerFactory.register("custom")
|
||||
class CustomScheduler(BaseScheduler):
|
||||
...
|
||||
|
||||
|
||||
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
||||
|
||||
|
||||
# Or from config
|
||||
config = CosineScheduleConfig(total_steps=10000)
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
"""
|
||||
|
||||
|
||||
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
|
||||
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new scheduler class.
|
||||
|
||||
|
||||
Args:
|
||||
name: Registration name for the scheduler
|
||||
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the scheduler class
|
||||
"""
|
||||
|
||||
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
raise TypeError(
|
||||
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
|
||||
)
|
||||
cls.SCHEDULER_MAP[name] = scheduler_cls
|
||||
return scheduler_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
|
||||
"""Create a scheduler instance by type name.
|
||||
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch optimizer
|
||||
schedule_type: Type of scheduler ("cosine", "sgdr")
|
||||
**kwargs: Arguments passed to the scheduler constructor
|
||||
|
||||
|
||||
Returns:
|
||||
Scheduler instance
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If schedule_type is not supported
|
||||
"""
|
||||
|
|
@ -82,25 +86,25 @@ class SchedulerFactory:
|
|||
f"Unknown schedule type: '{schedule_type}'. "
|
||||
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
|
||||
)
|
||||
|
||||
|
||||
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
|
||||
return scheduler_cls(optimizer, **kwargs)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
||||
"""Create a scheduler from a ScheduleConfig object.
|
||||
|
||||
|
||||
Args:
|
||||
optimizer: PyTorch optimizer
|
||||
schedule_config: ScheduleConfig instance
|
||||
|
||||
|
||||
Returns:
|
||||
Scheduler instance
|
||||
"""
|
||||
kwargs = schedule_config.get_kwargs()
|
||||
schedule_type = kwargs.pop("schedule_type")
|
||||
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def available_types(cls) -> list:
|
||||
"""Return list of registered scheduler type names."""
|
||||
|
|
@ -114,22 +118,21 @@ class SchedulerFactory:
|
|||
@SchedulerFactory.register("cosine")
|
||||
class CosineScheduler(BaseScheduler):
|
||||
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
warmup_steps: int,
|
||||
lr_decay_steps: int,
|
||||
min_rate: float = 0.05,
|
||||
last_epoch: int = -1
|
||||
self,
|
||||
optimizer,
|
||||
warmup_steps: int,
|
||||
lr_decay_steps: int,
|
||||
min_rate: float = 0.05,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
self.warmup_steps = warmup_steps
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.min_rate = min_rate
|
||||
self.total_steps = warmup_steps + lr_decay_steps
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
|
||||
|
||||
def get_lr(self) -> List[float]:
|
||||
# warmup
|
||||
if self.last_epoch < self.warmup_steps:
|
||||
|
|
@ -142,46 +145,47 @@ class CosineScheduler(BaseScheduler):
|
|||
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
|
||||
decay_factor = max(self.min_rate, cosine_decay)
|
||||
return [base_lr * decay_factor for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
def state_dict(self):
|
||||
state = super().state_dict()
|
||||
state.update({
|
||||
'warmup_steps': self.warmup_steps,
|
||||
'lr_decay_steps': self.lr_decay_steps,
|
||||
'min_rate': self.min_rate,
|
||||
'total_steps': self.total_steps,
|
||||
})
|
||||
state.update(
|
||||
{
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"lr_decay_steps": self.lr_decay_steps,
|
||||
"min_rate": self.min_rate,
|
||||
"total_steps": self.total_steps,
|
||||
}
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.warmup_steps = state_dict.pop('warmup_steps')
|
||||
self.lr_decay_steps = state_dict.pop('lr_decay_steps')
|
||||
self.min_rate = state_dict.pop('min_rate')
|
||||
self.total_steps = state_dict.pop('total_steps')
|
||||
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||
self.lr_decay_steps = state_dict.pop("lr_decay_steps")
|
||||
self.min_rate = state_dict.pop("min_rate")
|
||||
self.total_steps = state_dict.pop("total_steps")
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
@SchedulerFactory.register("sgdr")
|
||||
class SGDRScheduler(BaseScheduler):
|
||||
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
warmup_steps: int,
|
||||
cycle_length: int,
|
||||
min_rate: float = 0.05,
|
||||
t_mult: int = 2,
|
||||
last_epoch: int = -1,
|
||||
self,
|
||||
optimizer,
|
||||
warmup_steps: int,
|
||||
cycle_length: int,
|
||||
min_rate: float = 0.05,
|
||||
t_mult: int = 2,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
self.warmup_steps = warmup_steps
|
||||
self.cycle_length = cycle_length
|
||||
self.min_rate = min_rate
|
||||
self.t_mult = t_mult
|
||||
|
||||
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
|
||||
def get_lr(self):
|
||||
# warmup
|
||||
if self.last_epoch < self.warmup_steps:
|
||||
|
|
@ -190,40 +194,44 @@ class SGDRScheduler(BaseScheduler):
|
|||
|
||||
# SGDR
|
||||
steps_since_warmup = self.last_epoch - self.warmup_steps
|
||||
|
||||
|
||||
# 1. Calculate current cycle and position within cycle
|
||||
current_cycle_length = self.cycle_length
|
||||
total_cycles_length = 0
|
||||
cycle_num = 0
|
||||
|
||||
|
||||
while total_cycles_length + current_cycle_length <= steps_since_warmup:
|
||||
total_cycles_length += current_cycle_length
|
||||
current_cycle_length *= self.t_mult
|
||||
cycle_num += 1
|
||||
|
||||
|
||||
steps_in_cycle = steps_since_warmup - total_cycles_length
|
||||
|
||||
|
||||
# 2. Cosine annealing within the current cycle
|
||||
cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
|
||||
cosine_factor = 0.5 * (
|
||||
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
|
||||
)
|
||||
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
|
||||
|
||||
|
||||
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the scheduler as a dict."""
|
||||
state = super().state_dict()
|
||||
state.update({
|
||||
'warmup_steps': self.warmup_steps,
|
||||
'cycle_length': self.cycle_length,
|
||||
'min_rate': self.min_rate,
|
||||
't_mult': self.t_mult
|
||||
})
|
||||
state.update(
|
||||
{
|
||||
"warmup_steps": self.warmup_steps,
|
||||
"cycle_length": self.cycle_length,
|
||||
"min_rate": self.min_rate,
|
||||
"t_mult": self.t_mult,
|
||||
}
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the scheduler's state."""
|
||||
self.warmup_steps = state_dict.pop('warmup_steps')
|
||||
self.cycle_length = state_dict.pop('cycle_length')
|
||||
self.min_rate = state_dict.pop('min_rate')
|
||||
self.t_mult = state_dict.pop('t_mult')
|
||||
super().load_state_dict(state_dict)
|
||||
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||
self.cycle_length = state_dict.pop("cycle_length")
|
||||
self.min_rate = state_dict.pop("min_rate")
|
||||
self.t_mult = state_dict.pop("t_mult")
|
||||
super().load_state_dict(state_dict)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def unwrap_model(model: nn.Module) -> nn.Module:
|
|||
|
||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||
"""Create a reference model for DPO/GRPO training.
|
||||
|
||||
|
||||
Handles DDP-wrapped models safely by unwrapping first,
|
||||
then creating a deep copy with frozen gradients.
|
||||
"""
|
||||
|
|
@ -37,25 +37,27 @@ def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
|||
|
||||
|
||||
def get_logprobs(
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||
input_ids: Tensor,
|
||||
mask: Tensor,
|
||||
reduction: str,
|
||||
):
|
||||
"""Compute token-wise log probabilities from model outputs.
|
||||
|
||||
|
||||
Args:
|
||||
model: The language model
|
||||
input_ids: Input token IDs of shape [batch_size, seq_len]
|
||||
mask: Attention mask of shape [batch_size, seq_len]
|
||||
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
|
||||
|
||||
|
||||
Returns:
|
||||
Log probabilities with reduction applied over sequence dimension
|
||||
"""
|
||||
allowed_reductions = ["mean", "sum", "none"]
|
||||
if reduction not in allowed_reductions:
|
||||
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
|
||||
raise ValueError(
|
||||
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
|
||||
)
|
||||
|
||||
shifted_input_ids = input_ids[:, 1:]
|
||||
shifted_mask = mask[:, 1:]
|
||||
|
|
@ -64,13 +66,13 @@ def get_logprobs(
|
|||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
|
||||
token_logprobs = torch.gather(
|
||||
log_probs,
|
||||
dim=-1,
|
||||
index=shifted_input_ids.unsqueeze(-1)
|
||||
log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
|
||||
if reduction == "mean":
|
||||
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0)
|
||||
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
|
||||
dim=-1
|
||||
).clamp(min=1.0)
|
||||
elif reduction == "sum":
|
||||
return (token_logprobs * shifted_mask).sum(dim=-1)
|
||||
else:
|
||||
|
|
@ -79,23 +81,25 @@ def get_logprobs(
|
|||
|
||||
class BaseStrategy(ABC):
|
||||
"""Abstract base class for training strategies."""
|
||||
|
||||
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
|
||||
|
||||
def __init__(
|
||||
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
|
||||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""Compute loss for the given batch.
|
||||
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing batch tensors
|
||||
|
||||
|
||||
Returns:
|
||||
Computed loss tensor
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""Allow calling strategy directly as a callable."""
|
||||
return self.compute_loss(batch)
|
||||
|
|
@ -103,51 +107,55 @@ class BaseStrategy(ABC):
|
|||
|
||||
class StrategyFactory:
|
||||
"""Factory class for creating training strategy instances.
|
||||
|
||||
|
||||
Supports decorator-based registration for extensible strategy types.
|
||||
All default strategies (seq, sft, dpo, grpo) are automatically registered.
|
||||
|
||||
|
||||
Example usage:
|
||||
@StrategyFactory.register("custom")
|
||||
class CustomStrategy(BaseStrategy):
|
||||
...
|
||||
|
||||
|
||||
strategy = StrategyFactory.create(model, "custom", device)
|
||||
"""
|
||||
|
||||
|
||||
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||
STRATEGY_MAP: Dict[str, type] = {}
|
||||
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
"""Decorator to register a new strategy class.
|
||||
|
||||
|
||||
Args:
|
||||
name: Registration name for the strategy
|
||||
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the strategy class
|
||||
"""
|
||||
|
||||
def decorator(strategy_cls: type) -> type:
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
raise TypeError(
|
||||
f"{strategy_cls.__name__} must inherit from BaseStrategy"
|
||||
)
|
||||
cls.STRATEGY_MAP[name] = strategy_cls
|
||||
return strategy_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
|
||||
"""Create a strategy instance based on training type.
|
||||
|
||||
|
||||
Args:
|
||||
model: Model instance for the strategy
|
||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||
device: Device to run the strategy on
|
||||
**kwargs: Additional arguments passed to strategy constructor
|
||||
|
||||
|
||||
Returns:
|
||||
Strategy instance
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If train_type is not supported
|
||||
NotImplementedError: If train_type is in supported list but not implemented
|
||||
|
|
@ -157,15 +165,15 @@ class StrategyFactory:
|
|||
f"Unknown training strategy: '{train_type}'. "
|
||||
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
|
||||
)
|
||||
|
||||
|
||||
if train_type not in cls.STRATEGY_MAP:
|
||||
raise NotImplementedError(
|
||||
f"Strategy '{train_type}' is supported but not yet implemented."
|
||||
)
|
||||
|
||||
|
||||
strategy_cls = cls.STRATEGY_MAP[train_type]
|
||||
return strategy_cls(model, device, **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def available_strategies(cls) -> list:
|
||||
"""Return list of registered strategy names."""
|
||||
|
|
@ -179,77 +187,81 @@ class StrategyFactory:
|
|||
@StrategyFactory.register("seq")
|
||||
class SEQStrategy(BaseStrategy):
|
||||
"""Standard next-token prediction training strategy.
|
||||
|
||||
|
||||
Computes cross-entropy loss for next token prediction.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten(),
|
||||
label_smoothing=self.label_smoothing
|
||||
label_smoothing=self.label_smoothing,
|
||||
)
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@StrategyFactory.register("sft")
|
||||
class SFTStrategy(BaseStrategy):
|
||||
"""Supervised Fine-tuning strategy with loss masking.
|
||||
|
||||
|
||||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
|
||||
|
||||
input_ids, target_ids, loss_mask = (
|
||||
batch["input_ids"],
|
||||
batch["target_ids"],
|
||||
batch["loss_mask"],
|
||||
)
|
||||
|
||||
ignore_index = -100
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
target=target_ids.flatten(),
|
||||
ignore_index=ignore_index,
|
||||
label_smoothing=self.label_smoothing
|
||||
label_smoothing=self.label_smoothing,
|
||||
)
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@StrategyFactory.register("dpo")
|
||||
class DPOStrategy(BaseStrategy):
|
||||
"""Direct Preference Optimization strategy.
|
||||
|
||||
|
||||
Implements the DPO loss from the paper "Direct Preference Optimization".
|
||||
Uses a reference model to compute KL divergence penalty.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
beta: float = 0.1,
|
||||
reduction: str = "mean",
|
||||
):
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
beta: float = 0.1,
|
||||
reduction: str = "mean",
|
||||
):
|
||||
super().__init__(model, device)
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.beta = beta
|
||||
self.reduction = reduction
|
||||
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
||||
|
|
@ -257,17 +269,19 @@ class DPOStrategy(BaseStrategy):
|
|||
|
||||
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
||||
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
||||
|
||||
|
||||
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
|
||||
|
||||
with torch.no_grad():
|
||||
log_ref = get_logprobs(self.ref_model, contact_ids, contact_mask, self.reduction)
|
||||
|
||||
log_pi_chosen = log_pi[:chosen_ids.shape[0]]
|
||||
log_pi_rejected = log_pi[chosen_ids.shape[0]:]
|
||||
log_ref_chosen = log_ref[:chosen_ids.shape[0]]
|
||||
log_ref_rejected = log_ref[chosen_ids.shape[0]:]
|
||||
|
||||
log_ref = get_logprobs(
|
||||
self.ref_model, contact_ids, contact_mask, self.reduction
|
||||
)
|
||||
|
||||
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
||||
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
|
||||
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
|
||||
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
|
||||
|
||||
pi_log_ratio = log_pi_chosen - log_pi_rejected
|
||||
ref_log_ratio = log_ref_chosen - log_ref_rejected
|
||||
|
||||
|
|
@ -280,14 +294,14 @@ class DPOStrategy(BaseStrategy):
|
|||
@StrategyFactory.register("grpo")
|
||||
class GRPOStrategy(BaseStrategy):
|
||||
"""Group Relative Policy Optimization strategy.
|
||||
|
||||
|
||||
Implements GRPO with clipping and KL penalty.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
self,
|
||||
model: nn.Module,
|
||||
device: str,
|
||||
clip_eps: float = 0.2,
|
||||
kl_coef: float = 0.01,
|
||||
group_size: int = 4,
|
||||
|
|
@ -299,43 +313,47 @@ class GRPOStrategy(BaseStrategy):
|
|||
self.kl_coef = kl_coef
|
||||
self.group_size = group_size
|
||||
self.reduction = reduction
|
||||
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
prompts = batch["prompts"]
|
||||
responses = batch["responses"]
|
||||
masks = batch["masks"]
|
||||
rewards = batch["rewards"]
|
||||
|
||||
|
||||
batch_size, group_size, response_len = responses.shape
|
||||
responses_flat = responses.view(-1, response_len)
|
||||
masks_flat = masks.view(-1, response_len)
|
||||
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
|
||||
|
||||
|
||||
# Shape: (batch_size * group_size, seq_len + response_len)
|
||||
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
||||
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
||||
|
||||
log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction)
|
||||
|
||||
log_probs_policy = get_logprobs(
|
||||
self.model, full_sequences, full_masks, self.reduction
|
||||
)
|
||||
log_probs_policy = log_probs_policy.view(batch_size, group_size)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
|
||||
log_probs_ref = get_logprobs(
|
||||
self.ref_model, full_sequences, full_masks, self.reduction
|
||||
)
|
||||
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||
|
||||
|
||||
# Compute advantages from rewards with normalization
|
||||
eps = torch.finfo(log_probs_policy.dtype).eps
|
||||
mean = rewards.mean(dim=-1, keepdim=True)
|
||||
std = rewards.std(dim=-1, keepdim=True)
|
||||
advantages = (rewards - mean) / (std + eps)
|
||||
|
||||
|
||||
# PPO-style clipped surrogate objective
|
||||
ratio = torch.exp(0) # Off-policy: policy_model = old_model
|
||||
surr1 = ratio * advantages
|
||||
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
|
||||
|
||||
policy_loss = -torch.min(surr1, surr2).mean()
|
||||
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
|
||||
total_loss = policy_loss + kl_penalty
|
||||
|
||||
|
||||
return total_loss
|
||||
|
|
|
|||
|
|
@ -18,52 +18,53 @@ from khaosz.trainer.metric_util import (
|
|||
ctx_get_grad_norm,
|
||||
ctx_get_grad_mean,
|
||||
ctx_get_grad_std,
|
||||
ctx_get_grad_nan_num
|
||||
ctx_get_grad_nan_num,
|
||||
)
|
||||
from khaosz.data.serialization import Checkpoint
|
||||
from khaosz.trainer.train_context import TrainContext
|
||||
|
||||
|
||||
class TrainCallback(Protocol):
|
||||
"""
|
||||
"""
|
||||
Callback interface for trainer.
|
||||
"""
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
""" Called at the beginning of training. """
|
||||
|
||||
"""Called at the beginning of training."""
|
||||
|
||||
def on_train_end(self, context: TrainContext):
|
||||
""" Called at the end of training. """
|
||||
"""Called at the end of training."""
|
||||
|
||||
def on_epoch_begin(self, context: TrainContext):
|
||||
""" Called at the beginning of each epoch. """
|
||||
"""Called at the beginning of each epoch."""
|
||||
|
||||
def on_epoch_end(self, context: TrainContext):
|
||||
""" Called at the end of each epoch. """
|
||||
|
||||
"""Called at the end of each epoch."""
|
||||
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
""" Called at the beginning of each step. """
|
||||
"""Called at the beginning of each step."""
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
""" Called at the end of each step."""
|
||||
"""Called at the end of each step."""
|
||||
|
||||
def on_batch_begin(self, context: TrainContext):
|
||||
""" Called at the beginning of each batch. """
|
||||
"""Called at the beginning of each batch."""
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
""" Called at the end of each batch. """
|
||||
|
||||
"""Called at the end of each batch."""
|
||||
|
||||
def on_error(self, context: TrainContext):
|
||||
""" Called when an error occurs during training. """
|
||||
"""Called when an error occurs during training."""
|
||||
|
||||
|
||||
class GradientClippingCallback(TrainCallback):
|
||||
"""
|
||||
"""
|
||||
Gradient clipping callback for trainer.
|
||||
"""
|
||||
|
||||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
_ = context
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
|
@ -73,86 +74,95 @@ class SchedulerCallback(TrainCallback):
|
|||
"""
|
||||
Scheduler callback for trainer.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
for group in context.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
|
||||
class CheckpointCallback(TrainCallback):
|
||||
"""
|
||||
"""
|
||||
Checkpoint callback for trainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
self,
|
||||
save_dir: str,
|
||||
interval: int,
|
||||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_checkpoint(self, context: TrainContext):
|
||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
|
||||
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
state_dict = (
|
||||
self.state_dict_fn(context.model)
|
||||
if self.state_dict_fn
|
||||
else context.model.state_dict()
|
||||
)
|
||||
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration
|
||||
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
self._save_checkpoint(context)
|
||||
|
||||
|
||||
def on_train_end(self, context: TrainContext):
|
||||
if context.iteration != self.last_ckpt_iter:
|
||||
self._save_checkpoint(context)
|
||||
|
||||
|
||||
def on_error(self, context: TrainContext):
|
||||
self._save_checkpoint(context)
|
||||
|
||||
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
"""
|
||||
"""
|
||||
Progress bar callback for trainer.
|
||||
"""
|
||||
|
||||
def __init__(self, num_epoch: int):
|
||||
self.num_epoch = num_epoch
|
||||
self.progress_bar: tqdm = None
|
||||
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_epoch_begin(self, context: TrainContext):
|
||||
self.progress_bar = tqdm(
|
||||
context.dataloader,
|
||||
desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
|
||||
dynamic_ncols=True
|
||||
context.dataloader,
|
||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
self.progress_bar.set_postfix({
|
||||
"loss": f"{context.loss:.4f}",
|
||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
|
||||
})
|
||||
self.progress_bar.set_postfix(
|
||||
{
|
||||
"loss": f"{context.loss:.4f}",
|
||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||
}
|
||||
)
|
||||
self.progress_bar.update(1)
|
||||
|
||||
|
||||
@only_on_rank(0)
|
||||
def on_epoch_end(self, context: TrainContext):
|
||||
_ = context
|
||||
|
|
@ -162,66 +172,65 @@ class ProgressBarCallback(TrainCallback):
|
|||
|
||||
class MetricLoggerCallback(TrainCallback):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir:str,
|
||||
save_interval:int,
|
||||
log_interval:int=10,
|
||||
metrics:List[str]=None
|
||||
self,
|
||||
log_dir: str,
|
||||
save_interval: int,
|
||||
log_interval: int = 10,
|
||||
metrics: List[str] = None,
|
||||
):
|
||||
self.last_log_iter = 0
|
||||
self.save_interval = save_interval
|
||||
self.log_interval = log_interval
|
||||
self.metrics = metrics or ['loss', 'lr']
|
||||
|
||||
self.metrics = metrics or ["loss", "lr"]
|
||||
|
||||
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
self.log_cache = []
|
||||
|
||||
|
||||
self._metric_funcs = {
|
||||
'loss': ctx_get_loss,
|
||||
'lr': ctx_get_lr,
|
||||
'grad_norm': ctx_get_grad_norm,
|
||||
'grad_std': ctx_get_grad_std,
|
||||
'grad_max': ctx_get_grad_max,
|
||||
'grad_min': ctx_get_grad_min,
|
||||
'grad_mean': ctx_get_grad_mean,
|
||||
'grad_nan_num': ctx_get_grad_nan_num
|
||||
"loss": ctx_get_loss,
|
||||
"lr": ctx_get_lr,
|
||||
"grad_norm": ctx_get_grad_norm,
|
||||
"grad_std": ctx_get_grad_std,
|
||||
"grad_max": ctx_get_grad_max,
|
||||
"grad_min": ctx_get_grad_min,
|
||||
"grad_mean": ctx_get_grad_mean,
|
||||
"grad_nan_num": ctx_get_grad_nan_num,
|
||||
}
|
||||
|
||||
def _get_log_data(self, context: TrainContext):
|
||||
return {
|
||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"epoch": context.epoch,
|
||||
"iter": context.iteration,
|
||||
**{m: self._metric_funcs[m](context) for m in self.metrics}
|
||||
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||
}
|
||||
|
||||
|
||||
@only_on_rank(0)
|
||||
def _add_log(self, log_data):
|
||||
self.log_cache.append(log_data)
|
||||
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_log(self, epoch, iter):
|
||||
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
|
||||
|
||||
with open(log_file, 'w') as f:
|
||||
|
||||
with open(log_file, "w") as f:
|
||||
for log in self.log_cache:
|
||||
f.write(json.dumps(log) + '\n')
|
||||
|
||||
f.write(json.dumps(log) + "\n")
|
||||
|
||||
def on_batch_end(self, context):
|
||||
if context.iteration % self.log_interval == 0:
|
||||
log_data = self._get_log_data(context)
|
||||
self._add_log(log_data)
|
||||
|
||||
|
||||
if context.iteration - self.last_log_iter >= self.save_interval:
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
self.last_log_iter = context.iteration
|
||||
|
||||
|
||||
def on_train_end(self, context):
|
||||
if context.iteration != self.last_log_iter:
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
||||
|
||||
def on_error(self, context):
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
||||
|
|
@ -21,11 +21,11 @@ class TrainContext:
|
|||
optimizer: Optimizer = field(default=None)
|
||||
scheduler: LRScheduler = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
|
||||
|
||||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
|
||||
|
||||
world_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
|
|
@ -39,17 +39,17 @@ class TrainContextBuilder:
|
|||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
)
|
||||
|
||||
|
||||
device = get_current_device()
|
||||
self._context.model = self._context.model.to(device=device)
|
||||
|
||||
|
||||
if self.config.nprocs > 1:
|
||||
fn = self.config.parallel_wrapper
|
||||
self._context.model = fn(self._context.model)
|
||||
|
||||
|
||||
self._context.optimizer = self.config.optimizer_fn(self._context.model)
|
||||
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
|
||||
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
|
|
@ -60,10 +60,10 @@ class TrainContextBuilder:
|
|||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||
self._context.model.load_state_dict(checkpoint.state_dict)
|
||||
|
||||
|
||||
self._context.checkpoint = checkpoint
|
||||
return self
|
||||
|
||||
|
||||
def with_dataloader(self) -> Self:
|
||||
# fix: change batch level iteration to sample level offset
|
||||
config = self.config
|
||||
|
|
@ -72,28 +72,28 @@ class TrainContextBuilder:
|
|||
data_source=config.dataset,
|
||||
start_epoch=self._context.epoch,
|
||||
start_iter=sampler_offset,
|
||||
seed=config.random_seed
|
||||
seed=config.random_seed,
|
||||
)
|
||||
|
||||
|
||||
dataloader = DataLoader(
|
||||
config.dataset,
|
||||
batch_size=config.batch_size,
|
||||
config.dataset,
|
||||
batch_size=config.batch_size,
|
||||
sampler=resumeable_sampler,
|
||||
num_workers=config.num_workers,
|
||||
pin_memory=config.pin_memory,
|
||||
prefetch_factor=config.prefetch_factor
|
||||
prefetch_factor=config.prefetch_factor,
|
||||
)
|
||||
self._context.dataloader = dataloader
|
||||
return self
|
||||
|
||||
|
||||
def with_strategy(self) -> Self:
|
||||
self._context.strategy = StrategyFactory.create(
|
||||
model=self._context.model,
|
||||
train_type=self.config.strategy,
|
||||
device=get_current_device(),
|
||||
**self.config.extra_kwargs
|
||||
**self.config.extra_kwargs,
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
return self._context
|
||||
return self._context
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ import logging
|
|||
from typing import Optional, List
|
||||
from khaosz.config import TrainConfig
|
||||
from khaosz.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
ProgressBarCallback,
|
||||
TrainCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
MetricLoggerCallback,
|
||||
GradientClippingCallback,
|
||||
SchedulerCallback
|
||||
SchedulerCallback,
|
||||
)
|
||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
from khaosz.data.serialization import Checkpoint
|
||||
|
|
@ -18,37 +18,39 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
train_config: TrainConfig,
|
||||
callbacks: Optional[List[TrainCallback]] = None
|
||||
self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
|
||||
):
|
||||
self.train_config = train_config
|
||||
default_callbacks = self._get_default_callbacks()
|
||||
self.callbacks = default_callbacks + callbacks if callbacks else default_callbacks
|
||||
self.callbacks = (
|
||||
default_callbacks + callbacks if callbacks else default_callbacks
|
||||
)
|
||||
|
||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||
train_config = self.train_config
|
||||
return [
|
||||
ProgressBarCallback(train_config.n_epoch),
|
||||
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||
CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||
MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||
GradientClippingCallback(train_config.max_grad_norm),
|
||||
SchedulerCallback(),
|
||||
]
|
||||
|
||||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
return (TrainContextBuilder(self.train_config)
|
||||
.with_checkpoint(checkpoint)
|
||||
.with_dataloader()
|
||||
.with_strategy()
|
||||
.build())
|
||||
|
||||
return (
|
||||
TrainContextBuilder(self.train_config)
|
||||
.with_checkpoint(checkpoint)
|
||||
.with_dataloader()
|
||||
.with_strategy()
|
||||
.build()
|
||||
)
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
for callback in self.callbacks:
|
||||
method = getattr(callback, method_name, None)
|
||||
if method:
|
||||
method(context)
|
||||
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
config = self.train_config
|
||||
spawn_parallel_fn(
|
||||
|
|
@ -59,45 +61,45 @@ class Trainer:
|
|||
master_port=config.master_port,
|
||||
device_type=config.device_type,
|
||||
device_ids=config.device_ids,
|
||||
checkpoint=checkpoint
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||
context = self._build_context(checkpoint)
|
||||
self._call_callbacks('on_train_begin', context)
|
||||
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
||||
try:
|
||||
context.model.train()
|
||||
# 1.epoch
|
||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks('on_epoch_begin', context)
|
||||
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
# 3. batch
|
||||
self._call_callbacks('on_batch_begin', context)
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
context.iteration += 1
|
||||
|
||||
|
||||
# to make the loss normalized by accumulation steps
|
||||
stand_loss = loss / self.train_config.accumulation_steps
|
||||
stand_loss.backward()
|
||||
|
||||
self._call_callbacks('on_batch_end', context)
|
||||
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks('on_step_begin', context)
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', context)
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
self._call_callbacks('on_epoch_end', context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||
self._call_callbacks('on_error', context)
|
||||
self._call_callbacks("on_error", context)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks('on_train_end', context)
|
||||
self._call_callbacks("on_train_end", context)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ classifiers = [
|
|||
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest==9.0.2"]
|
||||
dev = ["pytest==9.0.2", "ruff"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
|
|
@ -35,4 +35,13 @@ where = ["."]
|
|||
extra-index-url = "https://download.pytorch.org/whl/cu126"
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = { attr = "khaosz.__version__" }
|
||||
version = { attr = "khaosz.__version__" }
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
|
|
@ -17,14 +17,14 @@ class RandomDataset(Dataset):
|
|||
self.length = length or int(np.random.randint(100, 200))
|
||||
self.max_length = max_length
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
||||
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,))
|
||||
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -33,10 +33,10 @@ class MultiTurnDataset(Dataset):
|
|||
self.length = length or int(np.random.randint(100, 200))
|
||||
self.max_length = max_length
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
||||
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
||||
|
|
@ -54,18 +54,18 @@ class EarlyStoppingDataset(Dataset):
|
|||
self.length = length
|
||||
self.stop_after = stop_after
|
||||
self.count = 0
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
self.count += 1
|
||||
if self.count == self.stop_after:
|
||||
raise RuntimeError("Simulated early stopping")
|
||||
|
||||
|
||||
return {
|
||||
"input_ids": torch.randint(0, 1000, (64,)),
|
||||
"target_ids": torch.randint(0, 1000, (64,))
|
||||
"target_ids": torch.randint(0, 1000, (64,)),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -74,10 +74,10 @@ def base_test_env(request: pytest.FixtureRequest):
|
|||
func_name = request.function.__name__
|
||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
||||
|
||||
n_dim_choices = [8, 16, 32]
|
||||
n_head_choices = [2, 4]
|
||||
|
||||
|
||||
dim = int(np.random.choice(n_dim_choices))
|
||||
n_heads = int(np.random.choice(n_head_choices))
|
||||
n_kv_heads = n_heads // 2
|
||||
|
|
@ -91,16 +91,16 @@ def base_test_env(request: pytest.FixtureRequest):
|
|||
"dim_ffn": dim_ffn,
|
||||
"max_len": 1024,
|
||||
"n_layers": 4,
|
||||
"norm_eps": 1e-5
|
||||
"norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config).to(device=device)
|
||||
tokenizer = BpeTokenizer()
|
||||
|
||||
|
||||
yield {
|
||||
"device": device,
|
||||
"test_dir": str(test_dir),
|
||||
|
|
@ -109,20 +109,23 @@ def base_test_env(request: pytest.FixtureRequest):
|
|||
"model": model,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
|
||||
|
||||
shutil.rmtree(test_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_dataset():
|
||||
dataset = RandomDataset()
|
||||
yield dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_turn_dataset():
|
||||
dataset = MultiTurnDataset()
|
||||
yield dataset
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def early_stopping_dataset():
|
||||
dataset = EarlyStoppingDataset()
|
||||
yield dataset
|
||||
yield dataset
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
|||
from khaosz.data.serialization import Checkpoint
|
||||
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
||||
|
||||
|
||||
def test_single_process():
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
|
|
@ -14,34 +15,31 @@ def test_single_process():
|
|||
|
||||
for epoch in range(3):
|
||||
for iteration in range(10):
|
||||
|
||||
x = torch.randn(32, 10)
|
||||
y = torch.randn(32, 5)
|
||||
loss = model(x).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
scheduler.step()
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
state_dict=model.state_dict(),
|
||||
epoch=3,
|
||||
iteration=30
|
||||
)
|
||||
|
||||
|
||||
checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint.save(tmpdir)
|
||||
|
||||
|
||||
loaded_checkpoint = Checkpoint.load(tmpdir)
|
||||
|
||||
|
||||
assert loaded_checkpoint.epoch == 3
|
||||
assert loaded_checkpoint.iteration == 30
|
||||
|
||||
|
||||
def simple_training():
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=10)
|
||||
|
||||
|
||||
for epoch in range(2):
|
||||
for iteration in range(5):
|
||||
x = torch.randn(16, 10)
|
||||
|
|
@ -57,28 +55,23 @@ def simple_training():
|
|||
epoch=2,
|
||||
iteration=10,
|
||||
)
|
||||
|
||||
|
||||
rank = get_rank()
|
||||
|
||||
|
||||
if rank == 0:
|
||||
shared_dir = tempfile.mkdtemp()
|
||||
checkpoint.save(shared_dir)
|
||||
else:
|
||||
shared_dir = None
|
||||
|
||||
|
||||
if dist.is_initialized():
|
||||
dir_list = [shared_dir]
|
||||
dist.broadcast_object_list(dir_list, src=0)
|
||||
shared_dir = dir_list[0]
|
||||
|
||||
|
||||
|
||||
loaded = Checkpoint.load(shared_dir)
|
||||
assert loaded.epoch == 2
|
||||
|
||||
|
||||
def test_multi_process():
|
||||
spawn_parallel_fn(
|
||||
simple_training,
|
||||
world_size=2,
|
||||
backend="gloo"
|
||||
)
|
||||
spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
|
||||
|
|
|
|||
|
|
@ -5,30 +5,32 @@ from khaosz.data.serialization import save_h5
|
|||
from khaosz.data.dataset import *
|
||||
|
||||
|
||||
|
||||
def test_dataset_loader_random_paths(base_test_env):
|
||||
"""Test dataset loader with multiple random paths"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
|
||||
# Create multiple mmap dataset directories with random data
|
||||
num_files = np.random.randint(2, 5)
|
||||
|
||||
|
||||
for i in range(num_files):
|
||||
seq_length = np.random.randint(200, 400)
|
||||
dummy_data = {
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
|
||||
"sequence": [
|
||||
torch.randint(0, 1000, (seq_length,), dtype=torch.int64)
|
||||
for _ in range(10)
|
||||
],
|
||||
}
|
||||
save_h5(test_dir, f"data_{i}", dummy_data)
|
||||
|
||||
|
||||
# Test loading with multiple paths
|
||||
loaded_dataset = DatasetLoader.load(
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
assert loaded_dataset is not None
|
||||
assert len(loaded_dataset) > 0
|
||||
|
||||
|
||||
# Test that we can get items without errors
|
||||
for i in range(len(loaded_dataset)):
|
||||
item = loaded_dataset[i]
|
||||
|
|
@ -41,30 +43,30 @@ def test_dataset_loader_random_paths(base_test_env):
|
|||
def test_dpo_strategy_with_random_data(base_test_env):
|
||||
"""Test DPO strategy with randomized preference data"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
|
||||
# Create DPO-style data with memory mapping format
|
||||
seq_length = np.random.randint(100, 200)
|
||||
|
||||
|
||||
dummy_data = {
|
||||
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
||||
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||
}
|
||||
|
||||
|
||||
save_h5(test_dir, "dpo_data", dummy_data)
|
||||
|
||||
|
||||
# Load DPO dataset
|
||||
dpo_dataset = DatasetLoader.load(
|
||||
train_type="dpo",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
train_type="dpo",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
|
||||
assert dpo_dataset is not None
|
||||
assert hasattr(dpo_dataset, 'fetcher')
|
||||
assert hasattr(dpo_dataset, "fetcher")
|
||||
assert len(dpo_dataset) > 0
|
||||
|
||||
|
||||
# Test that we can get DPO items without errors
|
||||
for i in range(min(3, len(dpo_dataset))):
|
||||
item = dpo_dataset[i]
|
||||
|
|
@ -79,28 +81,28 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
def test_sft_dataset_with_random_data(base_test_env):
|
||||
"""Test SFT dataset with random data"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
|
||||
# Create SFT-style data with memory mapping format
|
||||
seq_length = np.random.randint(100, 200)
|
||||
|
||||
|
||||
dummy_data = {
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||
}
|
||||
|
||||
|
||||
save_h5(test_dir, "sft_data", dummy_data)
|
||||
|
||||
|
||||
# Load SFT dataset
|
||||
sft_dataset = DatasetLoader.load(
|
||||
train_type="sft",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
train_type="sft",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
|
||||
assert sft_dataset is not None
|
||||
assert hasattr(sft_dataset, 'fetcher')
|
||||
assert hasattr(sft_dataset, "fetcher")
|
||||
assert len(sft_dataset) > 0
|
||||
|
||||
|
||||
# Test that we can get SFT items without errors
|
||||
for i in range(min(3, len(sft_dataset))):
|
||||
item = sft_dataset[i]
|
||||
|
|
@ -114,33 +116,30 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
def test_dataset_with_custom_stride(base_test_env):
|
||||
"""Test dataset with custom stride parameter"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
|
||||
# Create test data
|
||||
seq_length = 200
|
||||
dummy_data = {
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
}
|
||||
|
||||
save_h5(test_dir,"stride_test_data", dummy_data)
|
||||
|
||||
|
||||
save_h5(test_dir, "stride_test_data", dummy_data)
|
||||
|
||||
# Test with custom stride
|
||||
custom_stride = 32
|
||||
dataset = DatasetLoader.load(
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
stride=custom_stride
|
||||
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
|
||||
)
|
||||
|
||||
|
||||
assert dataset is not None
|
||||
assert len(dataset) > 0
|
||||
|
||||
|
||||
# With stride 32 and window 64 on 200 length data, we should get more samples
|
||||
# than with default stride (which equals window size)
|
||||
default_stride_dataset = DatasetLoader.load(
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
|
||||
assert len(dataset) > len(default_stride_dataset)
|
||||
|
|
|
|||
|
|
@ -1,30 +1,32 @@
|
|||
from khaosz.trainer import *
|
||||
from khaosz.data import *
|
||||
|
||||
|
||||
def test_random_sampler_consistency(random_dataset):
|
||||
"""Test RandomSampler produces consistent results with same seed"""
|
||||
dataset = random_dataset
|
||||
|
||||
|
||||
# Create two samplers with same seed
|
||||
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||
sampler2 = ResumableDistributedSampler(dataset, seed=42)
|
||||
|
||||
|
||||
indices1 = list(iter(sampler1))
|
||||
indices2 = list(iter(sampler2))
|
||||
|
||||
|
||||
assert indices1 == indices2
|
||||
|
||||
|
||||
def test_random_sampler_different_seeds(random_dataset):
|
||||
"""Test RandomSampler produces different results with different seeds"""
|
||||
dataset = random_dataset
|
||||
|
||||
|
||||
# Create two samplers with different seeds
|
||||
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||
sampler2 = ResumableDistributedSampler(dataset, seed=123)
|
||||
|
||||
|
||||
indices1 = list(iter(sampler1))
|
||||
indices2 = list(iter(sampler2))
|
||||
|
||||
|
||||
# Very high probability they should be different
|
||||
assert indices1 != indices2
|
||||
|
||||
|
|
@ -33,20 +35,20 @@ def test_sampler_across_epochs(random_dataset):
|
|||
"""Test sampler behavior across multiple epochs"""
|
||||
dataset = random_dataset
|
||||
n = len(dataset)
|
||||
|
||||
|
||||
sampler = ResumableDistributedSampler(dataset, seed=42)
|
||||
|
||||
|
||||
# Get indices for first epoch
|
||||
epoch1_indices = list(iter(sampler))
|
||||
assert len(epoch1_indices) == n
|
||||
|
||||
|
||||
# Get indices for second epoch
|
||||
epoch2_indices = list(iter(sampler))
|
||||
assert len(epoch2_indices) == n
|
||||
|
||||
|
||||
# Check that epochs have different order (should be random)
|
||||
assert epoch1_indices != epoch2_indices
|
||||
|
||||
|
||||
# Check that all indices are present in each epoch
|
||||
assert set(epoch1_indices) == set(range(n))
|
||||
assert set(epoch2_indices) == set(range(n))
|
||||
assert set(epoch2_indices) == set(range(n))
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from khaosz.data import *
|
|||
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
||||
from tokenizers import pre_tokenizers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_env(request: pytest.FixtureRequest):
|
||||
func_name = request.function.__name__
|
||||
|
|
@ -19,7 +20,7 @@ def test_env(request: pytest.FixtureRequest):
|
|||
config_path = os.path.join(test_dir, "config.json")
|
||||
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
||||
model_path = os.path.join(test_dir, "model.safetensors")
|
||||
|
||||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"dim": 128,
|
||||
|
|
@ -28,20 +29,20 @@ def test_env(request: pytest.FixtureRequest):
|
|||
"dim_ffn": 256,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5
|
||||
"norm_eps": 1e-5,
|
||||
}
|
||||
with open(config_path, 'w') as f:
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
tokenizer = BpeTokenizer()
|
||||
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
|
||||
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
||||
tokenizer.save(tokenizer_path)
|
||||
|
||||
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config)
|
||||
st.save_file(model.state_dict(), model_path)
|
||||
|
||||
|
||||
yield {
|
||||
"test_dir": test_dir,
|
||||
"model": model,
|
||||
|
|
@ -51,47 +52,55 @@ def test_env(request: pytest.FixtureRequest):
|
|||
|
||||
shutil.rmtree(test_dir)
|
||||
|
||||
|
||||
def test_model_parameter(test_env):
|
||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
|
||||
model_param = ModelParameter(
|
||||
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||
)
|
||||
ModelParameter.save(model_param, save_dir)
|
||||
|
||||
|
||||
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
|
||||
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
||||
assert os.path.exists(os.path.join(save_dir, "config.json"))
|
||||
|
||||
|
||||
# transformer
|
||||
def test_transformer(test_env):
|
||||
model = test_env["model"]
|
||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
|
||||
(4, test_env["transformer_config"].max_len))
|
||||
input_ids = torch.randint(
|
||||
0,
|
||||
test_env["transformer_config"].vocab_size,
|
||||
(4, test_env["transformer_config"].max_len),
|
||||
)
|
||||
output_logits = model(input_ids)["logits"]
|
||||
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
|
||||
target_shape = (
|
||||
4,
|
||||
test_env["transformer_config"].max_len,
|
||||
test_env["transformer_config"].vocab_size,
|
||||
)
|
||||
assert output_logits.shape == target_shape
|
||||
|
||||
|
||||
|
||||
# generator
|
||||
def test_embedding_encoder_core(test_env):
|
||||
parameter = ModelParameter(
|
||||
test_env["model"],
|
||||
test_env["tokenizer"],
|
||||
test_env["transformer_config"]
|
||||
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||
)
|
||||
encoder = EmbeddingEncoderCore(parameter)
|
||||
|
||||
|
||||
single_emb = encoder.encode("测试文本")
|
||||
assert isinstance(single_emb, torch.Tensor)
|
||||
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
||||
|
||||
|
||||
batch_emb = encoder.encode(["测试1", "测试2"])
|
||||
assert isinstance(batch_emb, list)
|
||||
assert len(batch_emb) == 2
|
||||
|
||||
|
||||
|
||||
def test_generator_core(test_env):
|
||||
parameter = ModelParameter(
|
||||
test_env["model"],
|
||||
test_env["tokenizer"],
|
||||
test_env["transformer_config"]
|
||||
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||
)
|
||||
generator = GeneratorCore(parameter)
|
||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
||||
|
|
@ -102,8 +111,8 @@ def test_generator_core(test_env):
|
|||
top_p=0.95,
|
||||
attn_mask=None,
|
||||
kv_caches=None,
|
||||
start_pos=0
|
||||
start_pos=0,
|
||||
)
|
||||
|
||||
|
||||
assert next_token_id.shape == (4, 1)
|
||||
assert cache_increase == 10
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def transformer_test_env():
|
|||
"""创建Transformer测试专用环境"""
|
||||
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
||||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"dim": 128,
|
||||
|
|
@ -22,18 +22,14 @@ def transformer_test_env():
|
|||
"dim_ffn": 256,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5
|
||||
"norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
yield {
|
||||
"test_dir": test_dir,
|
||||
"config_path": config_path,
|
||||
"config": config
|
||||
}
|
||||
|
||||
|
||||
yield {"test_dir": test_dir, "config_path": config_path, "config": config}
|
||||
|
||||
if os.path.exists(test_dir):
|
||||
try:
|
||||
for file in os.listdir(test_dir):
|
||||
|
|
@ -46,74 +42,75 @@ def transformer_test_env():
|
|||
def test_tie_weight_init(transformer_test_env):
|
||||
config_path = transformer_test_env["config_path"]
|
||||
config_data = transformer_test_env["config"].copy()
|
||||
|
||||
|
||||
# case 1: tie weight
|
||||
config_data["tie_weight"] = True
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
model = Transformer(config)
|
||||
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||
|
||||
|
||||
original_weight = model.embed_tokens.weight.clone()
|
||||
model.embed_tokens.weight.data[0, 0] = 100.0
|
||||
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||
|
||||
|
||||
# case 2: not tie weight
|
||||
config_data["tie_weight"] = False
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
model = Transformer(config)
|
||||
|
||||
|
||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||
|
||||
original_weight = model.embed_tokens.weight.clone()
|
||||
model.embed_tokens.weight.data[0, 0] = 100.0
|
||||
|
||||
|
||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||
|
||||
|
||||
def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||
test_dir = transformer_test_env["test_dir"]
|
||||
model_path = os.path.join(test_dir, "model.safetensors")
|
||||
|
||||
|
||||
config_data = transformer_test_env["config"].copy()
|
||||
|
||||
|
||||
# case 1: tie weight
|
||||
config_data["tie_weight"] = True
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
|
||||
config = ModelConfig().load(config_path)
|
||||
original_model = Transformer(config)
|
||||
|
||||
|
||||
st.save_file(original_model.state_dict(), model_path)
|
||||
|
||||
loaded_config = ModelConfig().load(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||
assert "lm_head.weight" not in model.state_dict()
|
||||
|
||||
# case 2: not tie weight (form tie-weight state dict load)
|
||||
config_data["tie_weight"] = False
|
||||
with open(config_path, 'w') as f:
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
|
||||
loaded_config = ModelConfig().load(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
|
@ -121,4 +118,3 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||
assert "lm_head.weight" in model.state_dict()
|
||||
|
||||
|
|
@ -1,16 +1,14 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from khaosz.parallel import (
|
||||
get_rank,
|
||||
only_on_rank,
|
||||
spawn_parallel_fn
|
||||
)
|
||||
from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
|
||||
|
||||
|
||||
@only_on_rank(0)
|
||||
def _test_only_on_rank_helper():
|
||||
return True
|
||||
|
||||
|
||||
def only_on_rank():
|
||||
result = _test_only_on_rank_helper()
|
||||
if get_rank() == 0:
|
||||
|
|
@ -18,22 +16,17 @@ def only_on_rank():
|
|||
else:
|
||||
assert result is None
|
||||
|
||||
|
||||
def all_reduce():
|
||||
x = torch.tensor([get_rank()], dtype=torch.int)
|
||||
dist.all_reduce(x, op=dist.ReduceOp.SUM)
|
||||
expected_sum = sum(range(dist.get_world_size()))
|
||||
assert x.item() == expected_sum
|
||||
|
||||
|
||||
def test_spawn_only_on_rank():
|
||||
spawn_parallel_fn(
|
||||
only_on_rank,
|
||||
world_size=2,
|
||||
backend="gloo"
|
||||
)
|
||||
spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo")
|
||||
|
||||
|
||||
def test_spawn_all_reduce():
|
||||
spawn_parallel_fn(
|
||||
all_reduce,
|
||||
world_size=2,
|
||||
backend="gloo"
|
||||
)
|
||||
spawn_parallel_fn(all_reduce, world_size=2, backend="gloo")
|
||||
|
|
|
|||
|
|
@ -3,57 +3,48 @@ import torch
|
|||
from khaosz.config import *
|
||||
from khaosz.trainer import *
|
||||
|
||||
|
||||
def test_callback_integration(base_test_env, random_dataset):
|
||||
"""Test that all callbacks are properly integrated"""
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
|
||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||
|
||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
strategy='seq',
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=2,
|
||||
checkpoint_interval=3,
|
||||
ckpt_interval=3,
|
||||
accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"]
|
||||
device_type=base_test_env["device"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Create custom callbacks to track calls
|
||||
callback_calls = []
|
||||
|
||||
|
||||
class TrackingCallback(TrainCallback):
|
||||
def on_train_begin(self, context):
|
||||
callback_calls.append('on_train_begin')
|
||||
|
||||
def on_batch_end(self, context):
|
||||
callback_calls.append('on_batch_end')
|
||||
|
||||
def on_epoch_end(self, context):
|
||||
callback_calls.append('on_epoch_end')
|
||||
|
||||
callback_calls.append("on_train_begin")
|
||||
|
||||
def on_batch_end(self, context):
|
||||
callback_calls.append("on_batch_end")
|
||||
|
||||
def on_epoch_end(self, context):
|
||||
callback_calls.append("on_epoch_end")
|
||||
|
||||
trainer = Trainer(train_config, callbacks=[TrackingCallback()])
|
||||
|
||||
|
||||
trainer = Trainer(
|
||||
train_config,
|
||||
callbacks=[TrackingCallback()]
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
# Verify callbacks were called
|
||||
assert 'on_train_begin' in callback_calls
|
||||
assert 'on_batch_end' in callback_calls
|
||||
assert 'on_epoch_end' in callback_calls
|
||||
assert "on_train_begin" in callback_calls
|
||||
assert "on_batch_end" in callback_calls
|
||||
assert "on_epoch_end" in callback_calls
|
||||
|
|
|
|||
|
|
@ -5,31 +5,32 @@ from khaosz.config import *
|
|||
from khaosz.trainer import *
|
||||
from khaosz.data.serialization import Checkpoint
|
||||
|
||||
|
||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||
"""Simulate early stopping behavior"""
|
||||
|
||||
|
||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||
|
||||
|
||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||
|
||||
|
||||
train_config = TrainConfig(
|
||||
strategy="seq",
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
model=base_test_env["model"],
|
||||
dataset=early_stopping_dataset,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=2,
|
||||
batch_size=2,
|
||||
checkpoint_interval=1,
|
||||
ckpt_interval=1,
|
||||
accumulation_steps=2,
|
||||
random_seed=np.random.randint(1e4),
|
||||
device_type=base_test_env["device"]
|
||||
device_type=base_test_env["device"],
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
|
||||
|
||||
# Should handle early stopping gracefully
|
||||
checkpoint = None
|
||||
try:
|
||||
|
|
@ -37,11 +38,11 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
except Exception:
|
||||
# Handle any exceptions
|
||||
pass
|
||||
|
||||
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
trainer.train(checkpoint)
|
||||
|
||||
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
assert checkpoint.iteration == 10
|
||||
assert checkpoint.iteration == 10
|
||||
|
|
|
|||
|
|
@ -9,39 +9,41 @@ from khaosz.data.dataset import *
|
|||
|
||||
def test_schedule_factory_random_configs():
|
||||
"""Test scheduler factory with random configurations"""
|
||||
|
||||
|
||||
# Create a simple model and optimizer for testing
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||
|
||||
|
||||
# Test multiple random configurations
|
||||
for _ in range(5): # Test 5 random configurations
|
||||
schedule_configs = [
|
||||
CosineScheduleConfig(
|
||||
warmup_steps=np.random.randint(50, 200),
|
||||
total_steps=np.random.randint(1000, 5000),
|
||||
min_rate=np.random.uniform(0.01, 0.1)
|
||||
min_rate=np.random.uniform(0.01, 0.1),
|
||||
),
|
||||
SGDRScheduleConfig(
|
||||
warmup_steps=np.random.randint(50, 200),
|
||||
cycle_length=np.random.randint(500, 2000),
|
||||
t_mult=np.random.randint(1, 3),
|
||||
min_rate=np.random.uniform(0.01, 0.1)
|
||||
)
|
||||
min_rate=np.random.uniform(0.01, 0.1),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
for config in schedule_configs:
|
||||
# Validate configuration
|
||||
config.validate()
|
||||
|
||||
|
||||
# Create scheduler using factory
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
|
||||
|
||||
# Verify scheduler type
|
||||
if isinstance(config, CosineScheduleConfig):
|
||||
assert isinstance(scheduler, CosineScheduler)
|
||||
assert scheduler.warmup_steps == config.warmup_steps
|
||||
assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
||||
assert (
|
||||
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
||||
)
|
||||
assert scheduler.min_rate == config.min_rate
|
||||
elif isinstance(config, SGDRScheduleConfig):
|
||||
assert isinstance(scheduler, SGDRScheduler)
|
||||
|
|
@ -49,17 +51,17 @@ def test_schedule_factory_random_configs():
|
|||
assert scheduler.cycle_length == config.cycle_length
|
||||
assert scheduler.t_mult == config.t_mult
|
||||
assert scheduler.min_rate == config.min_rate
|
||||
|
||||
|
||||
# Test scheduler state dict functionality
|
||||
state_dict = scheduler.state_dict()
|
||||
assert 'warmup_steps' in state_dict
|
||||
assert 'min_rate' in state_dict
|
||||
|
||||
assert "warmup_steps" in state_dict
|
||||
assert "min_rate" in state_dict
|
||||
|
||||
# Test scheduler step functionality
|
||||
initial_lr = scheduler.get_last_lr()
|
||||
scheduler.step()
|
||||
new_lr = scheduler.get_last_lr()
|
||||
|
||||
|
||||
# Learning rate should change after step, or if it's the first step,
|
||||
# the epoch counter should increment
|
||||
assert initial_lr != new_lr or scheduler.last_epoch > -1
|
||||
|
|
@ -67,10 +69,10 @@ def test_schedule_factory_random_configs():
|
|||
|
||||
def test_schedule_factory_edge_cases():
|
||||
"""Test scheduler factory with edge cases and boundary conditions"""
|
||||
|
||||
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||
|
||||
|
||||
# Test edge cases for CosineScheduleConfig
|
||||
edge_cases = [
|
||||
# Minimal warmup and steps
|
||||
|
|
@ -80,12 +82,12 @@ def test_schedule_factory_edge_cases():
|
|||
# Zero min_rate (edge case)
|
||||
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
|
||||
]
|
||||
|
||||
|
||||
for config in edge_cases:
|
||||
config.validate()
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
assert scheduler is not None
|
||||
|
||||
|
||||
# Test multiple steps
|
||||
for _ in range(10):
|
||||
scheduler.step()
|
||||
|
|
@ -93,7 +95,7 @@ def test_schedule_factory_edge_cases():
|
|||
|
||||
def test_schedule_factory_invalid_configs():
|
||||
"""Test scheduler factory with invalid configurations"""
|
||||
|
||||
|
||||
# Test invalid configurations that should raise errors
|
||||
invalid_configs = [
|
||||
# Negative warmup steps
|
||||
|
|
@ -104,7 +106,7 @@ def test_schedule_factory_invalid_configs():
|
|||
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
|
||||
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
|
||||
]
|
||||
|
||||
|
||||
for kwargs in invalid_configs:
|
||||
with pytest.raises(ValueError):
|
||||
config = CosineScheduleConfig(**kwargs)
|
||||
|
|
@ -113,24 +115,24 @@ def test_schedule_factory_invalid_configs():
|
|||
|
||||
def test_schedule_factory_state_persistence():
|
||||
"""Test scheduler state persistence (save/load)"""
|
||||
|
||||
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||
|
||||
|
||||
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
|
||||
scheduler = SchedulerFactory.load(optimizer, config)
|
||||
|
||||
|
||||
# Take a few steps
|
||||
for _ in range(5):
|
||||
scheduler.step()
|
||||
|
||||
|
||||
# Save state
|
||||
state_dict = scheduler.state_dict()
|
||||
|
||||
|
||||
# Create new scheduler and load state
|
||||
new_scheduler = SchedulerFactory.load(optimizer, config)
|
||||
new_scheduler.load_state_dict(state_dict)
|
||||
|
||||
|
||||
# Verify states match
|
||||
assert scheduler.last_epoch == new_scheduler.last_epoch
|
||||
assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
|
||||
assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
|
||||
|
|
|
|||
|
|
@ -6,100 +6,94 @@ from khaosz.config import *
|
|||
from khaosz.trainer import *
|
||||
from khaosz.data.dataset import *
|
||||
|
||||
|
||||
def test_different_batch_sizes(base_test_env, random_dataset):
|
||||
"""Test training with different batch sizes"""
|
||||
batch_sizes = [1, 2, 4, 8]
|
||||
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||
|
||||
|
||||
train_config = TrainConfig(
|
||||
strategy="seq",
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=batch_size,
|
||||
checkpoint_interval=5,
|
||||
ckpt_interval=5,
|
||||
accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=np.random.randint(1000),
|
||||
device_type=base_test_env["device"]
|
||||
device_type=base_test_env["device"],
|
||||
)
|
||||
|
||||
|
||||
assert train_config.batch_size == batch_size
|
||||
|
||||
|
||||
def test_gradient_accumulation(base_test_env, random_dataset):
|
||||
"""Test training with different gradient accumulation steps"""
|
||||
accumulation_steps_list = [1, 2, 4]
|
||||
|
||||
|
||||
for accumulation_steps in accumulation_steps_list:
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||
|
||||
|
||||
train_config = TrainConfig(
|
||||
strategy="seq",
|
||||
model=base_test_env["model"],
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
dataset=random_dataset,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=2,
|
||||
checkpoint_interval=10,
|
||||
ckpt_interval=10,
|
||||
accumulation_steps=accumulation_steps,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"]
|
||||
device_type=base_test_env["device"],
|
||||
)
|
||||
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
|
||||
|
||||
assert train_config.accumulation_steps == accumulation_steps
|
||||
|
||||
|
||||
def test_memory_efficient_training(base_test_env, random_dataset):
|
||||
"""Test training with memory-efficient configurations"""
|
||||
# Test with smaller batch sizes and gradient checkpointing
|
||||
small_batch_configs = [
|
||||
{"batch_size": 1, "accumulation_steps": 8},
|
||||
{"batch_size": 2, "accumulation_steps": 4},
|
||||
{"batch_size": 4, "accumulation_steps": 2}
|
||||
{"batch_size": 4, "accumulation_steps": 2},
|
||||
]
|
||||
|
||||
|
||||
for config in small_batch_configs:
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=10,
|
||||
total_steps=20
|
||||
)
|
||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||
|
||||
|
||||
train_config = TrainConfig(
|
||||
strategy="seq",
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_size=config["batch_size"],
|
||||
checkpoint_interval=5,
|
||||
ckpt_interval=5,
|
||||
accumulation_steps=config["accumulation_steps"],
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"]
|
||||
device_type=base_test_env["device"],
|
||||
)
|
||||
|
||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||
|
||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||
|
|
|
|||
|
|
@ -17,41 +17,47 @@ class GenerationBenchmark:
|
|||
self,
|
||||
config: ModelConfig,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16
|
||||
dtype: torch.dtype = torch.float16,
|
||||
):
|
||||
self.config = config
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||
self.model.eval()
|
||||
|
||||
|
||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||
"""初始化KV缓存"""
|
||||
config = self.config
|
||||
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
|
||||
shape = (
|
||||
batch_size,
|
||||
config.max_len,
|
||||
config.n_layers,
|
||||
config.n_kv_heads,
|
||||
config.dim // config.n_heads,
|
||||
)
|
||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
return (k_cache, v_cache)
|
||||
|
||||
|
||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||
prompt_ids = torch.randint(
|
||||
low=0,
|
||||
high=self.config.vocab_size,
|
||||
size=(batch_size, prompt_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
gen_ids = torch.randint(
|
||||
low=0,
|
||||
high=self.config.vocab_size,
|
||||
size=(batch_size, total_length - prompt_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
return prompt_ids, gen_ids
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_benchmark(
|
||||
self,
|
||||
|
|
@ -59,32 +65,38 @@ class GenerationBenchmark:
|
|||
prompt_length: int = 512,
|
||||
num_trials: int = 10,
|
||||
) -> BenchmarkResult:
|
||||
|
||||
|
||||
for _ in range(3):
|
||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
||||
prompt_ids, _ = self._prepare_inputs(
|
||||
batch_size, prompt_length, prompt_length
|
||||
)
|
||||
_ = self.model(prompt_ids)
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
total_time = 0.0
|
||||
total_tokens = batch_size * prompt_length * num_trials
|
||||
|
||||
|
||||
for trial in range(num_trials):
|
||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
||||
prompt_ids, _ = self._prepare_inputs(
|
||||
batch_size, prompt_length, prompt_length
|
||||
)
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
|
||||
start_event.record()
|
||||
_ = self.model(prompt_ids)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||
total_time += trial_time
|
||||
|
||||
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||
f"({prompt_length / trial_time:.1f} tokens/s)")
|
||||
|
||||
|
||||
print(
|
||||
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||
f"({prompt_length / trial_time:.1f} tokens/s)"
|
||||
)
|
||||
|
||||
return BenchmarkResult(
|
||||
total_tokens=total_tokens,
|
||||
total_time=total_time,
|
||||
|
|
@ -95,9 +107,9 @@ class GenerationBenchmark:
|
|||
"prompt_length": prompt_length,
|
||||
"dtype": self.dtype,
|
||||
"device": self.device,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_decoding_benchmark(
|
||||
self,
|
||||
|
|
@ -106,39 +118,43 @@ class GenerationBenchmark:
|
|||
gen_length: int = 128,
|
||||
num_trials: int = 5,
|
||||
) -> BenchmarkResult:
|
||||
|
||||
|
||||
total_time = 0.0
|
||||
total_tokens = batch_size * gen_length * num_trials
|
||||
|
||||
|
||||
for trial in range(num_trials):
|
||||
|
||||
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
||||
prompt_ids, gen_ids = self._prepare_inputs(
|
||||
batch_size, prompt_length, prompt_length + gen_length
|
||||
)
|
||||
kv_cache = self._initialize_kv_cache(batch_size)
|
||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
|
||||
start_event.record()
|
||||
|
||||
|
||||
current_pos = prompt_length
|
||||
for i in range(gen_length):
|
||||
input_token = gen_ids[:, i:i+1]
|
||||
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
|
||||
input_token = gen_ids[:, i : i + 1]
|
||||
_ = self.model(
|
||||
input_token, persistent_key_values=kv_cache, start_pos=current_pos
|
||||
)
|
||||
current_pos += 1
|
||||
|
||||
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||
total_time += trial_time
|
||||
|
||||
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||
f"({gen_length / trial_time:.1f} tokens/s)")
|
||||
|
||||
|
||||
print(
|
||||
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||
f"({gen_length / trial_time:.1f} tokens/s)"
|
||||
)
|
||||
|
||||
return BenchmarkResult(
|
||||
total_tokens=total_tokens,
|
||||
total_time=total_time,
|
||||
|
|
@ -150,24 +166,28 @@ class GenerationBenchmark:
|
|||
"gen_length": gen_length,
|
||||
"dtype": self.dtype,
|
||||
"device": self.device,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def print_benchmark_result(result: BenchmarkResult):
|
||||
"""打印基准测试结果"""
|
||||
benchmark_type = result.metadata["benchmark_type"]
|
||||
|
||||
|
||||
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
|
||||
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
||||
print(f"Time Consumed: {result.total_time:.3f}s")
|
||||
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
||||
|
||||
|
||||
if benchmark_type == "prefill":
|
||||
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
|
||||
print(
|
||||
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
|
||||
)
|
||||
elif benchmark_type == "decoding":
|
||||
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
|
||||
|
||||
print(
|
||||
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
|
||||
)
|
||||
|
||||
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
||||
print("-" * 80)
|
||||
|
||||
|
|
@ -183,16 +203,19 @@ if __name__ == "__main__":
|
|||
n_layers=24,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
|
||||
benchmark = GenerationBenchmark(config)
|
||||
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Transformer Generation Benchmark")
|
||||
print("=" * 80)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(
|
||||
batch_size=4, prompt_length=512, num_trials=5
|
||||
)
|
||||
print_benchmark_result(prefill_result)
|
||||
|
||||
gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5)
|
||||
|
||||
gen_result = benchmark.run_decoding_benchmark(
|
||||
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
|
||||
)
|
||||
print_benchmark_result(gen_result)
|
||||
|
||||
|
|
@ -21,10 +21,10 @@ def processor(
|
|||
with disable_random_init():
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
param.to(device="cuda", dtype=torch.bfloat16)
|
||||
generator = BatchGenerator(param)
|
||||
|
||||
with open(input_json_file, "r", encoding='utf-8') as f:
|
||||
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||
input_data = [json.loads(line) for line in f]
|
||||
|
||||
queries = [item[question_key] for item in input_data]
|
||||
|
|
@ -41,26 +41,62 @@ def processor(
|
|||
|
||||
responses = generator.generate(request)
|
||||
|
||||
with open(output_json_file, "w", encoding='utf-8') as f:
|
||||
with open(output_json_file, "w", encoding="utf-8") as f:
|
||||
for query, response in zip(queries, responses):
|
||||
output_item = {question_key: query, response_key: response}
|
||||
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
|
||||
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||
|
||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
||||
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
|
||||
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
|
||||
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
|
||||
parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.")
|
||||
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
|
||||
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
|
||||
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
||||
|
||||
parser.add_argument(
|
||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_json_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input JSONL file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_json_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output JSONL file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--question_key",
|
||||
type=str,
|
||||
default="question",
|
||||
help="Key for the question in the input JSON.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--response_key",
|
||||
type=str,
|
||||
default="response",
|
||||
help="Key for the response in the output JSON.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.60,
|
||||
help="Temperature for generating responses.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_k", type=int, default=30, help="Top-k value for generating responses."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_p",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Top-p value for generating responses.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="Batch size for generating responses."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
processor(**vars(args))
|
||||
processor(**vars(args))
|
||||
|
|
|
|||
|
|
@ -11,89 +11,99 @@ from khaosz.inference.core import disable_random_init
|
|||
|
||||
|
||||
def compute_perplexity(
|
||||
model: nn.Module,
|
||||
input_ids: Tensor,
|
||||
input_mask: Tensor,
|
||||
) -> Tensor:
|
||||
model: nn.Module,
|
||||
input_ids: Tensor,
|
||||
input_mask: Tensor,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Compute the perplexity of a batch of input sequences,
|
||||
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
||||
Compute the perplexity of a batch of input sequences,
|
||||
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
||||
"""
|
||||
|
||||
|
||||
output = model(input_ids, input_mask)
|
||||
logits = output["logits"]
|
||||
|
||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
||||
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
||||
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
shifted_logits.flatten(0, 1),
|
||||
shifted_input_ids.flatten(0, 1),
|
||||
reduction='none'
|
||||
shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
|
||||
)
|
||||
|
||||
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
||||
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
||||
loss = loss * shifted_mask
|
||||
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
|
||||
perplexity = torch.exp(sentence_loss) # [batch_size]
|
||||
|
||||
perplexity = torch.exp(sentence_loss) # [batch_size]
|
||||
|
||||
return perplexity
|
||||
|
||||
|
||||
def process_file(
|
||||
model_dir: str,
|
||||
input_file: str,
|
||||
output_file: str,
|
||||
batch_size: int,
|
||||
text_key: str
|
||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||
):
|
||||
with disable_random_init():
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
param.to(device="cuda", dtype=torch.bfloat16)
|
||||
model = param.model
|
||||
tokenizer = param.tokenizer
|
||||
|
||||
with open(input_file, "r", encoding='utf-8') as f:
|
||||
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
input_data = [json.loads(line) for line in f]
|
||||
|
||||
|
||||
texts = [item[text_key] for item in input_data]
|
||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||
output_data = []
|
||||
|
||||
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
|
||||
batch_encoded = encoded_texts[i:i + batch_size]
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
for i in tqdm(
|
||||
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
|
||||
):
|
||||
batch_encoded = encoded_texts[i : i + batch_size]
|
||||
batch_texts = texts[i : i + batch_size]
|
||||
max_len = max(len(seq) for seq in batch_encoded)
|
||||
padded_ids = []
|
||||
masks = []
|
||||
|
||||
|
||||
for seq in batch_encoded:
|
||||
pad_len = max_len - len(seq)
|
||||
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
||||
mask = [False] * pad_len + [True] * len(seq)
|
||||
padded_ids.append(padded_seq)
|
||||
masks.append(mask)
|
||||
|
||||
|
||||
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
||||
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
||||
perplexity = compute_perplexity(model, input_ids, input_mask)
|
||||
|
||||
|
||||
for text, ppl in zip(batch_texts, perplexity):
|
||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||
|
||||
with open(output_file, "w", encoding='utf-8') as f:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for item in output_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
||||
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
|
||||
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
|
||||
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
|
||||
parser.add_argument(
|
||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_file", type=str, required=True, help="Path to the input file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_file", type=str, required=True, help="Path to the output file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=4, help="Batch size for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_key",
|
||||
type=str,
|
||||
default="text",
|
||||
help="Key for the text field in the input data.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with torch.inference_mode():
|
||||
|
|
|
|||
205
tools/train.py
205
tools/train.py
|
|
@ -15,40 +15,130 @@ from khaosz.parallel import get_rank
|
|||
def parse_args() -> argparse.Namespace:
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||
|
||||
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
|
||||
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
||||
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
||||
|
||||
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
|
||||
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
|
||||
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
|
||||
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
|
||||
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
|
||||
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
|
||||
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
|
||||
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
|
||||
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
||||
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
||||
|
||||
parser.add_argument(
|
||||
"--train_type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["seq", "sft", "dpo"],
|
||||
help="Train type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_root_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the root directory of the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--param_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the model parameters or resume checkpoint.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="Batch size for training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of iterations between each optimizer step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iters between warnings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_grad_norm",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Max gradient norm for clipping.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adamw_beta1",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Beta values for AdamW optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adamw_beta2",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Beta values for AdamW optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adamw_weight_decay",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Weight decay for AdamW optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_seed", type=int, default=3407, help="Random seed for reproducibility."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, default=4, help="Number of workers for data loading."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_pin_memory",
|
||||
action="store_false",
|
||||
dest="pin_memory",
|
||||
help="Disable pin memory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--window_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="the max length of the input sequence.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stride", type=int, default=None, help="the step size of the input sequence."
|
||||
)
|
||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter")
|
||||
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||
|
||||
parser.add_argument(
|
||||
"--label_smoothing",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="cross_entropy function label smoothing parameter",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ckpt_interval",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Number of iters between checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
default="checkpoint",
|
||||
help="Directory to save checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start_epoch", type=int, default=0, help="Start epoch for training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||
)
|
||||
|
||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
||||
|
||||
parser.add_argument(
|
||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def ddp_wrap(model: nn.Module):
|
||||
local_rank = get_rank()
|
||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
||||
|
|
@ -56,16 +146,21 @@ def ddp_wrap(model: nn.Module):
|
|||
model,
|
||||
device_ids=[local_rank],
|
||||
output_device=local_rank,
|
||||
find_unused_parameters=False
|
||||
find_unused_parameters=False,
|
||||
)
|
||||
return ddp_model
|
||||
|
||||
|
||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||
return optim.AdamW(model.parameters(), **kwargs)
|
||||
|
||||
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
|
||||
|
||||
def create_scheduler(
|
||||
optimizer: optim.Optimizer, **kwargs
|
||||
) -> optim.lr_scheduler.LRScheduler:
|
||||
return SchedulerFactory.load(optimizer, **kwargs)
|
||||
|
||||
|
||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||
return model.module.state_dict()
|
||||
|
||||
|
|
@ -81,8 +176,8 @@ def train(
|
|||
start_batch: int,
|
||||
accumulation_steps: int,
|
||||
warmup_steps: int,
|
||||
checkpoint_interval: int,
|
||||
checkpoint_dir: str,
|
||||
ckpt_interval: int,
|
||||
ckpt_dir: str,
|
||||
dpo_beta: float,
|
||||
adamw_beta1: float,
|
||||
adamw_beta2: float,
|
||||
|
|
@ -99,48 +194,50 @@ def train(
|
|||
):
|
||||
assert train_type in ["seq", "sft", "dpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
||||
|
||||
parameter = ModelParameter.load(param_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = parameter.config.max_len
|
||||
|
||||
model = parameter.model
|
||||
|
||||
strategy_kwargs = {
|
||||
"dpo_beta": dpo_beta,
|
||||
"label_smoothing": label_smoothing
|
||||
}
|
||||
|
||||
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
||||
|
||||
dataset = DatasetLoader.load(
|
||||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
window_size=window_size,
|
||||
stride=stride
|
||||
)
|
||||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=warmup_steps,
|
||||
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
||||
stride=stride,
|
||||
)
|
||||
|
||||
optimizer_fn = partial(create_optimizer,
|
||||
**{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
|
||||
scheduler_fn = partial(create_scheduler,
|
||||
**{"schedule_config": schedule_config})
|
||||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
warmup_steps=warmup_steps,
|
||||
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
||||
)
|
||||
|
||||
optimizer_fn = partial(
|
||||
create_optimizer,
|
||||
**{
|
||||
"lr": max_lr,
|
||||
"betas": (adamw_beta1, adamw_beta2),
|
||||
"weight_decay": adamw_weight_decay,
|
||||
},
|
||||
)
|
||||
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=model,
|
||||
strategy=train_type,
|
||||
dataset=dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
ckpt_dir=ckpt_dir,
|
||||
n_epoch=n_epoch,
|
||||
batch_size=batch_size,
|
||||
start_epoch=start_epoch,
|
||||
start_batch=start_batch,
|
||||
checkpoint_interval=checkpoint_interval,
|
||||
ckpt_interval=ckpt_interval,
|
||||
accumulation_steps=accumulation_steps,
|
||||
max_grad_norm=max_grad_norm,
|
||||
random_seed=random_seed,
|
||||
|
|
@ -152,11 +249,11 @@ def train(
|
|||
device_type=device_type,
|
||||
extra_kwargs=strategy_kwargs,
|
||||
)
|
||||
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
train(**vars(args))
|
||||
train(**vars(args))
|
||||
|
|
|
|||
Loading…
Reference in New Issue