style: 使用ruff 工具优化代码风格
This commit is contained in:
parent
345fd2f091
commit
426af2d75f
|
|
@ -0,0 +1,27 @@
|
||||||
|
name: Lint
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python 3.12
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Check formatting with ruff
|
||||||
|
run: |
|
||||||
|
ruff format --check .
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
name: Spell Check
|
|
||||||
on: [push, pull_request]
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
spellcheck:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Check spelling in specific files
|
|
||||||
uses: codespell-project/actions-codespell@v2
|
|
||||||
with:
|
|
||||||
check_filenames: true
|
|
||||||
only_warn: false
|
|
||||||
path: "**/*.{md, py}"
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
name: Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.12"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
run: |
|
||||||
|
python -m pytest tests/ -v
|
||||||
|
|
@ -8,5 +8,8 @@
|
||||||
!*.py
|
!*.py
|
||||||
!*.md
|
!*.md
|
||||||
!*.png
|
!*.png
|
||||||
|
|
||||||
!LICENSE
|
!LICENSE
|
||||||
!pyproject.toml
|
!pyproject.toml
|
||||||
|
!.github/workflows/lint.yml
|
||||||
|
!.github/workflows/tests.yml
|
||||||
16
README.md
16
README.md
|
|
@ -54,8 +54,8 @@ python train.py \
|
||||||
--n_epoch=5 \
|
--n_epoch=5 \
|
||||||
--batch_size=8 \
|
--batch_size=8 \
|
||||||
--max_lr=2e-4 \
|
--max_lr=2e-4 \
|
||||||
--checkpoint_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
--checkpoint_dir=checkpoints
|
--ckpt_dir=checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
**Parameter Explanation:**
|
**Parameter Explanation:**
|
||||||
|
|
@ -67,8 +67,8 @@ python train.py \
|
||||||
- `--accumulation_steps`: Number of batches per training step
|
- `--accumulation_steps`: Number of batches per training step
|
||||||
- `--warmup_steps`: Warmup steps
|
- `--warmup_steps`: Warmup steps
|
||||||
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
||||||
- `--checkpoint_interval`: Checkpoint saving interval
|
- `--ckpt_interval`: Checkpoint saving interval
|
||||||
- `--checkpoint_dir`: Checkpoint saving directory
|
- `--ckpt_dir`: Checkpoint saving directory
|
||||||
- `--resume_dir`: Resume training from specified path
|
- `--resume_dir`: Resume training from specified path
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -191,8 +191,8 @@ python train.py \
|
||||||
--n_epoch=5 \
|
--n_epoch=5 \
|
||||||
--batch_size=8 \
|
--batch_size=8 \
|
||||||
--max_lr=2e-4 \
|
--max_lr=2e-4 \
|
||||||
--checkpoint_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
--checkpoint_dir=checkpoints
|
--ckpt_dir=checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
**参数说明:**
|
**参数说明:**
|
||||||
|
|
@ -204,8 +204,8 @@ python train.py \
|
||||||
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
||||||
- `--warmup_steps`: 预热步数(warmup steps)
|
- `--warmup_steps`: 预热步数(warmup steps)
|
||||||
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
||||||
- `--checkpoint_interval`: 检查点保存间隔
|
- `--ckpt_interval`: 检查点保存间隔
|
||||||
- `--checkpoint_dir`: 检查点保存目录
|
- `--ckpt_dir`: 检查点保存目录
|
||||||
- `--resume_dir`: 从指定路径恢复训练
|
- `--resume_dir`: 从指定路径恢复训练
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,12 @@ import os
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id="ViperEk/KHAOSZ",
|
repo_id="ViperEk/KHAOSZ",
|
||||||
local_dir=os.path.join(PROJECT_ROOT, "params"),
|
local_dir=os.path.join(PROJECT_ROOT, "params"),
|
||||||
force_download=True
|
force_download=True,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,18 +5,18 @@ from khaosz.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import LoopGenerator, GenerationRequest
|
from khaosz.inference.generator import LoopGenerator, GenerationRequest
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def generate_text():
|
def generate_text():
|
||||||
|
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
|
|
||||||
request = GenerationRequest(
|
request = GenerationRequest(
|
||||||
query=query,
|
query=query,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
|
|
@ -28,8 +28,9 @@ def generate_text():
|
||||||
)
|
)
|
||||||
generator = LoopGenerator(param)
|
generator = LoopGenerator(param)
|
||||||
response = generator.generate(request)
|
response = generator.generate(request)
|
||||||
|
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
generate_text()
|
generate_text()
|
||||||
|
|
|
||||||
|
|
@ -4,18 +4,24 @@ from khaosz.config.param_config import ModelParameter
|
||||||
from khaosz.inference.core import disable_random_init
|
from khaosz.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def batch_generate():
|
def batch_generate():
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
generator = BatchGenerator(param)
|
generator = BatchGenerator(param)
|
||||||
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
|
inputs = [
|
||||||
|
"你好",
|
||||||
|
"请问什么是人工智能",
|
||||||
|
"今天天气如何",
|
||||||
|
"我感到焦虑, 请问我应该怎么办",
|
||||||
|
"请问什么是显卡",
|
||||||
|
]
|
||||||
|
|
||||||
request = GenerationRequest(
|
request = GenerationRequest(
|
||||||
query=inputs,
|
query=inputs,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
|
|
@ -26,9 +32,10 @@ def batch_generate():
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
)
|
)
|
||||||
responses = generator.generate(request)
|
responses = generator.generate(request)
|
||||||
|
|
||||||
for q, r in zip(inputs, responses):
|
for q, r in zip(inputs, responses):
|
||||||
print((q, r))
|
print((q, r))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
batch_generate()
|
batch_generate()
|
||||||
|
|
|
||||||
|
|
@ -5,16 +5,16 @@ from khaosz.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def chat():
|
def chat():
|
||||||
|
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
generator = StreamGenerator(param)
|
generator = StreamGenerator(param)
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
|
|
@ -22,7 +22,7 @@ def chat():
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
if query == "!exit":
|
if query == "!exit":
|
||||||
break
|
break
|
||||||
|
|
||||||
request = GenerationRequest(
|
request = GenerationRequest(
|
||||||
query=query,
|
query=query,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
|
|
@ -32,7 +32,7 @@ def chat():
|
||||||
history=history,
|
history=history,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_size = 0
|
response_size = 0
|
||||||
full_response = ""
|
full_response = ""
|
||||||
for response in generator.generate(request):
|
for response in generator.generate(request):
|
||||||
|
|
@ -40,10 +40,10 @@ def chat():
|
||||||
print(response[response_size:], end="", flush=True)
|
print(response[response_size:], end="", flush=True)
|
||||||
response_size = len(response)
|
response_size = len(response)
|
||||||
full_response = response
|
full_response = response
|
||||||
|
|
||||||
# After generation, update history
|
# After generation, update history
|
||||||
history.append((query, full_response.strip()))
|
history.append((query, full_response.strip()))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
chat()
|
chat()
|
||||||
|
|
|
||||||
|
|
@ -6,41 +6,30 @@ from khaosz.config import (
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
from khaosz.data import (
|
from khaosz.data import DatasetLoader, BpeTokenizer
|
||||||
DatasetLoader,
|
|
||||||
BpeTokenizer
|
|
||||||
)
|
|
||||||
from khaosz.inference.generator import (
|
from khaosz.inference.generator import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
LoopGenerator,
|
LoopGenerator,
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
GeneratorFactory
|
GeneratorFactory,
|
||||||
)
|
|
||||||
from khaosz.trainer import (
|
|
||||||
Trainer,
|
|
||||||
StrategyFactory,
|
|
||||||
SchedulerFactory
|
|
||||||
)
|
)
|
||||||
|
from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Transformer",
|
"Transformer",
|
||||||
|
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
|
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"LoopGenerator",
|
"LoopGenerator",
|
||||||
"StreamGenerator",
|
"StreamGenerator",
|
||||||
"BatchGenerator",
|
"BatchGenerator",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
"GeneratorFactory",
|
"GeneratorFactory",
|
||||||
|
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory"
|
"SchedulerFactory",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from khaosz.config.model_config import ModelConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
||||||
from khaosz.config.schedule_config import (
|
from khaosz.config.schedule_config import (
|
||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
CosineScheduleConfig,
|
CosineScheduleConfig,
|
||||||
SGDRScheduleConfig,
|
SGDRScheduleConfig,
|
||||||
ScheduleConfigFactory
|
ScheduleConfigFactory,
|
||||||
)
|
)
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
@ -13,14 +13,12 @@ __all__ = [
|
||||||
# Base I/O
|
# Base I/O
|
||||||
"BaseModelIO",
|
"BaseModelIO",
|
||||||
"ModelParameter",
|
"ModelParameter",
|
||||||
|
|
||||||
# Model configuration
|
# Model configuration
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
# Schedule configuration
|
# Schedule configuration
|
||||||
"ScheduleConfig",
|
"ScheduleConfig",
|
||||||
"CosineScheduleConfig",
|
"CosineScheduleConfig",
|
||||||
"SGDRScheduleConfig",
|
"SGDRScheduleConfig",
|
||||||
"ScheduleConfigFactory",
|
"ScheduleConfigFactory",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -14,30 +14,29 @@ class ModelConfig:
|
||||||
norm_eps: Optional[float] = None
|
norm_eps: Optional[float] = None
|
||||||
dim_ffn: Optional[int] = None
|
dim_ffn: Optional[int] = None
|
||||||
tie_weight: Optional[bool] = None
|
tie_weight: Optional[bool] = None
|
||||||
|
|
||||||
# RoPE
|
# RoPE
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
|
|
||||||
# GQA
|
# GQA
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
use_qk_norm: Optional[bool] = None
|
use_qk_norm: Optional[bool] = None
|
||||||
use_gated_attention: Optional[bool] = None
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
def load(self, config_path: str) -> Self:
|
||||||
config = {}
|
config = {}
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, "r") as f:
|
||||||
config.update(json.load(f))
|
config.update(json.load(f))
|
||||||
|
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def save(self, config_path: str):
|
def save(self, config_path: str):
|
||||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_dict, f, indent=4)
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
|
||||||
|
|
@ -9,58 +9,57 @@ from khaosz.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.config.model_config import ModelConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseModelIO:
|
class BaseModelIO:
|
||||||
"""Base class for model I/O operations."""
|
"""Base class for model I/O operations."""
|
||||||
|
|
||||||
model: Optional[nn.Module] = field(
|
model: Optional[nn.Module] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Transformer model."}
|
||||||
metadata={"help": "Transformer model."}
|
|
||||||
)
|
)
|
||||||
tokenizer: BpeTokenizer = field(
|
tokenizer: BpeTokenizer = field(
|
||||||
default_factory=BpeTokenizer,
|
default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
|
||||||
metadata={"help": "Tokenizer for the model."}
|
|
||||||
)
|
)
|
||||||
config: ModelConfig = field(
|
config: ModelConfig = field(
|
||||||
default_factory=ModelConfig,
|
default_factory=ModelConfig,
|
||||||
metadata={"help": "Transformer model configuration."}
|
metadata={"help": "Transformer model configuration."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||||
"""Get standardized file paths for model components."""
|
"""Get standardized file paths for model components."""
|
||||||
dir_path = Path(directory)
|
dir_path = Path(directory)
|
||||||
return {
|
return {
|
||||||
"model": dir_path / "model.safetensors",
|
"model": dir_path / "model.safetensors",
|
||||||
"config": dir_path / "config.json",
|
"config": dir_path / "config.json",
|
||||||
"tokenizer": dir_path / "tokenizer.json"
|
"tokenizer": dir_path / "tokenizer.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_components(self, save_dir: Union[str, Path]):
|
def save_components(self, save_dir: Union[str, Path]):
|
||||||
"""Save core model components."""
|
"""Save core model components."""
|
||||||
paths = self._get_file_paths(save_dir)
|
paths = self._get_file_paths(save_dir)
|
||||||
paths["model"].parent.mkdir(parents=True, exist_ok=True)
|
paths["model"].parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
st.save_file(self.model.state_dict(), str(paths["model"]))
|
st.save_file(self.model.state_dict(), str(paths["model"]))
|
||||||
self.config.save(str(paths["config"]))
|
self.config.save(str(paths["config"]))
|
||||||
self.tokenizer.save(str(paths["tokenizer"]))
|
self.tokenizer.save(str(paths["tokenizer"]))
|
||||||
|
|
||||||
def load_components(self, load_dir: Union[str, Path]) -> Self:
|
def load_components(self, load_dir: Union[str, Path]) -> Self:
|
||||||
"""Load core model components."""
|
"""Load core model components."""
|
||||||
paths = self._get_file_paths(load_dir)
|
paths = self._get_file_paths(load_dir)
|
||||||
|
|
||||||
self.config.load(str(paths["config"]))
|
self.config.load(str(paths["config"]))
|
||||||
self.tokenizer.load(str(paths["tokenizer"]))
|
self.tokenizer.load(str(paths["tokenizer"]))
|
||||||
|
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.model = Transformer(self.config)
|
self.model = Transformer(self.config)
|
||||||
|
|
||||||
if paths["model"].exists():
|
if paths["model"].exists():
|
||||||
state_dict = st.load_file(str(paths["model"]))
|
state_dict = st.load_file(str(paths["model"]))
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to(self, *args, **kwargs) -> "BaseModelIO":
|
def to(self, *args, **kwargs) -> "BaseModelIO":
|
||||||
"""Move model to device."""
|
"""Move model to device."""
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
|
|
@ -71,13 +70,12 @@ class BaseModelIO:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelParameter(BaseModelIO):
|
class ModelParameter(BaseModelIO):
|
||||||
"""Container for model parameters with serialization capabilities."""
|
"""Container for model parameters with serialization capabilities."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
|
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
|
||||||
instance.save_components(save_dir)
|
instance.save_components(save_dir)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
|
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||||
instance = cls()
|
instance = cls()
|
||||||
return instance.load_components(load_dir)
|
return instance.load_components(load_dir)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,35 +6,35 @@ from dataclasses import dataclass, field
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScheduleConfig(ABC):
|
class ScheduleConfig(ABC):
|
||||||
"""Base configuration class for learning rate schedulers.
|
"""Base configuration class for learning rate schedulers.
|
||||||
|
|
||||||
Provides common validation and interface for all schedule types.
|
Provides common validation and interface for all schedule types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
schedule_type: str = field(
|
schedule_type: str = field(
|
||||||
default="cosine",
|
default="cosine",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Type of learning rate schedule.",
|
"help": "Type of learning rate schedule.",
|
||||||
"choices": ["cosine", "sgdr"]
|
"choices": ["cosine", "sgdr"],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
warmup_steps: int = field(
|
warmup_steps: int = field(
|
||||||
default=1000,
|
default=1000, metadata={"help": "Number of warmup steps."}
|
||||||
metadata={"help": "Number of warmup steps."}
|
|
||||||
)
|
)
|
||||||
min_rate: float = field(
|
min_rate: float = field(
|
||||||
default=0.05,
|
default=0.05, metadata={"help": "Minimum learning rate multiplier."}
|
||||||
metadata={"help": "Minimum learning rate multiplier."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
"""Get configuration kwargs for scheduler creation."""
|
"""Get configuration kwargs for scheduler creation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
"""Validate configuration parameters."""
|
"""Validate configuration parameters."""
|
||||||
if self.warmup_steps < 0:
|
if self.warmup_steps < 0:
|
||||||
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
|
raise ValueError(
|
||||||
|
f"warmup_steps must be non-negative, got {self.warmup_steps}"
|
||||||
|
)
|
||||||
if not 0 <= self.min_rate <= 1:
|
if not 0 <= self.min_rate <= 1:
|
||||||
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
||||||
|
|
||||||
|
|
@ -42,44 +42,43 @@ class ScheduleConfig(ABC):
|
||||||
@dataclass
|
@dataclass
|
||||||
class CosineScheduleConfig(ScheduleConfig):
|
class CosineScheduleConfig(ScheduleConfig):
|
||||||
"""Cosine annealing learning rate schedule configuration."""
|
"""Cosine annealing learning rate schedule configuration."""
|
||||||
|
|
||||||
total_steps: int = field(
|
total_steps: int = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Total training steps for cosine schedule."}
|
||||||
metadata={"help": "Total training steps for cosine schedule."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.schedule_type = "cosine"
|
self.schedule_type = "cosine"
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
if self.total_steps is None:
|
if self.total_steps is None:
|
||||||
raise ValueError("total_steps must be specified for cosine schedule")
|
raise ValueError("total_steps must be specified for cosine schedule")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"schedule_type": self.schedule_type,
|
"schedule_type": self.schedule_type,
|
||||||
"warmup_steps": self.warmup_steps,
|
"warmup_steps": self.warmup_steps,
|
||||||
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
||||||
"min_rate": self.min_rate
|
"min_rate": self.min_rate,
|
||||||
}
|
}
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
super().validate()
|
super().validate()
|
||||||
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
||||||
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
raise ValueError(
|
||||||
|
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SGDRScheduleConfig(ScheduleConfig):
|
class SGDRScheduleConfig(ScheduleConfig):
|
||||||
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
|
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
|
||||||
|
|
||||||
cycle_length: int = field(
|
cycle_length: int = field(
|
||||||
default=1000,
|
default=1000, metadata={"help": "Length of the first cycle in steps."}
|
||||||
metadata={"help": "Length of the first cycle in steps."}
|
|
||||||
)
|
)
|
||||||
t_mult: int = field(
|
t_mult: int = field(
|
||||||
default=2,
|
default=2, metadata={"help": "Multiplier for cycle length growth."}
|
||||||
metadata={"help": "Multiplier for cycle length growth."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|
@ -92,9 +91,9 @@ class SGDRScheduleConfig(ScheduleConfig):
|
||||||
"warmup_steps": self.warmup_steps,
|
"warmup_steps": self.warmup_steps,
|
||||||
"cycle_length": self.cycle_length,
|
"cycle_length": self.cycle_length,
|
||||||
"min_rate": self.min_rate,
|
"min_rate": self.min_rate,
|
||||||
"t_mult": self.t_mult
|
"t_mult": self.t_mult,
|
||||||
}
|
}
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
super().validate()
|
super().validate()
|
||||||
if self.cycle_length <= 0:
|
if self.cycle_length <= 0:
|
||||||
|
|
@ -105,33 +104,33 @@ class SGDRScheduleConfig(ScheduleConfig):
|
||||||
|
|
||||||
class ScheduleConfigFactory:
|
class ScheduleConfigFactory:
|
||||||
"""Factory class for creating ScheduleConfig instances.
|
"""Factory class for creating ScheduleConfig instances.
|
||||||
|
|
||||||
Supports both direct instantiation and factory creation methods.
|
Supports both direct instantiation and factory creation methods.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
# Direct creation
|
# Direct creation
|
||||||
config = CosineScheduleConfig(total_steps=10000)
|
config = CosineScheduleConfig(total_steps=10000)
|
||||||
|
|
||||||
# Factory method
|
# Factory method
|
||||||
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
|
config = ScheduleConfigFactory.create("cosine", total_steps=10000)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
|
CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
|
||||||
"cosine": CosineScheduleConfig,
|
"cosine": CosineScheduleConfig,
|
||||||
"sgdr": SGDRScheduleConfig,
|
"sgdr": SGDRScheduleConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
|
def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
|
||||||
"""Create a schedule config instance.
|
"""Create a schedule config instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
schedule_type: Type of schedule ("cosine", "sgdr")
|
schedule_type: Type of schedule ("cosine", "sgdr")
|
||||||
**kwargs: Arguments passed to the config constructor
|
**kwargs: Arguments passed to the config constructor
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ScheduleConfig instance
|
ScheduleConfig instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If schedule_type is not supported
|
ValueError: If schedule_type is not supported
|
||||||
"""
|
"""
|
||||||
|
|
@ -140,11 +139,11 @@ class ScheduleConfigFactory:
|
||||||
f"Unknown schedule type: '{schedule_type}'. "
|
f"Unknown schedule type: '{schedule_type}'. "
|
||||||
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
|
f"Supported types: {sorted(cls.CONFIG_MAP.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
config_cls = cls.CONFIG_MAP[schedule_type]
|
config_cls = cls.CONFIG_MAP[schedule_type]
|
||||||
return config_cls(**kwargs)
|
return config_cls(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def available_types(cls) -> list:
|
def available_types(cls) -> list:
|
||||||
"""Return list of available schedule type names."""
|
"""Return list of available schedule type names."""
|
||||||
return list(cls.CONFIG_MAP.keys())
|
return list(cls.CONFIG_MAP.keys())
|
||||||
|
|
|
||||||
|
|
@ -10,127 +10,92 @@ from typing import Callable, List, Optional
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
# basic setting
|
# basic setting
|
||||||
model: nn.Module = field(
|
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||||
default=None,
|
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||||
metadata={"help": "Model for training."}
|
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
|
||||||
)
|
|
||||||
strategy: str = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Training strategy."}
|
|
||||||
)
|
|
||||||
dataset: Dataset = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Dataset for training."}
|
|
||||||
)
|
|
||||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Optimizer factory for training."}
|
||||||
metadata={"help": "Optimizer factory for training."}
|
|
||||||
)
|
)
|
||||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Scheduler factory for training."}
|
||||||
metadata={"help": "Scheduler factory for training."}
|
|
||||||
)
|
|
||||||
n_epoch: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of epochs for training."}
|
|
||||||
)
|
|
||||||
batch_size: int = field(
|
|
||||||
default=4,
|
|
||||||
metadata={"help": "Batch size for training."}
|
|
||||||
)
|
)
|
||||||
|
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||||
|
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
||||||
accumulation_steps: int = field(
|
accumulation_steps: int = field(
|
||||||
default=1,
|
default=1, metadata={"help": "Number of iterations between steps."}
|
||||||
metadata={"help": "Number of iterations between steps."}
|
|
||||||
)
|
)
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
default=1.0,
|
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||||
metadata={"help": "Maximum gradient norm."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# checkpoint setting
|
# checkpoint setting
|
||||||
start_epoch: int = field(
|
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||||
default=0,
|
|
||||||
metadata={"help": "Start epoch for training."}
|
|
||||||
)
|
|
||||||
start_batch: int = field(
|
start_batch: int = field(
|
||||||
default=0,
|
default=0, metadata={"help": "Start batch iteration for training."}
|
||||||
metadata={"help": "Start batch iteration for training."}
|
|
||||||
)
|
)
|
||||||
checkpoint_dir: str = field(
|
ckpt_dir: str = field(
|
||||||
default="./checkpoint",
|
default="./checkpoint", metadata={"help": "Checkpoint directory."}
|
||||||
metadata={"help": "Checkpoint directory."}
|
|
||||||
)
|
)
|
||||||
checkpoint_interval: int = field(
|
ckpt_interval: int = field(
|
||||||
default=5000,
|
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||||
metadata={"help": "Number of iterations between checkpoints."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# dataloader setting
|
# dataloader setting
|
||||||
random_seed: int = field(
|
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||||
default=3407,
|
|
||||||
metadata={"help": "Random seed."}
|
|
||||||
)
|
|
||||||
num_workers: int = field(
|
num_workers: int = field(
|
||||||
default=0,
|
default=0, metadata={"help": "Number of workers for dataloader."}
|
||||||
metadata={"help": "Number of workers for dataloader."}
|
|
||||||
)
|
)
|
||||||
prefetch_factor: Optional[int] = field(
|
prefetch_factor: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Prefetch factor for dataloader."}
|
||||||
metadata={"help": "Prefetch factor for dataloader."}
|
|
||||||
)
|
)
|
||||||
pin_memory: bool = field(
|
pin_memory: bool = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Pin memory for dataloader."}
|
||||||
metadata={"help": "Pin memory for dataloader."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# distributed training
|
# distributed training
|
||||||
nprocs: int = field(
|
nprocs: int = field(
|
||||||
default=1,
|
default=1, metadata={"help": "Number of processes for distributed training."}
|
||||||
metadata={"help": "Number of processes for distributed training."}
|
|
||||||
)
|
)
|
||||||
backend: str = field(
|
backend: str = field(
|
||||||
default="nccl",
|
default="nccl", metadata={"help": "Distributed training backend."}
|
||||||
metadata={"help": "Distributed training backend."}
|
|
||||||
)
|
)
|
||||||
master_addr: str = field(
|
master_addr: str = field(
|
||||||
default="localhost",
|
default="localhost",
|
||||||
metadata={"help": "Master address for distributed training."}
|
metadata={"help": "Master address for distributed training."},
|
||||||
)
|
)
|
||||||
master_port: str = field(
|
master_port: str = field(
|
||||||
default="29500",
|
default="29500", metadata={"help": "Master port for distributed training."}
|
||||||
metadata={"help": "Master port for distributed training."}
|
|
||||||
)
|
)
|
||||||
parallel_wrapper: Optional[Callable] = field(
|
parallel_wrapper: Optional[Callable] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Parallel function for training."}
|
||||||
metadata={"help": "Parallel function for training."}
|
|
||||||
)
|
)
|
||||||
state_dict_fn: Optional[Callable] = field(
|
state_dict_fn: Optional[Callable] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||||
metadata={"help": "Parallel function for state dict saving."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
device_ids: Optional[List[int]] = field(
|
device_ids: Optional[List[int]] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Device ids for distributed training."}
|
||||||
metadata={"help": "Device ids for distributed training."}
|
|
||||||
)
|
)
|
||||||
device_type: str = field(
|
device_type: str = field(
|
||||||
default="cuda",
|
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||||
metadata={"help": "Device type for distributed training."}
|
|
||||||
)
|
)
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
default_factory=dict,
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
metadata={"help": "Other arguments."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
|
required_fields = [
|
||||||
|
"model",
|
||||||
|
"strategy",
|
||||||
|
"dataset",
|
||||||
|
"optimizer_fn",
|
||||||
|
"scheduler_fn",
|
||||||
|
]
|
||||||
|
|
||||||
for field_name in required_fields:
|
for field_name in required_fields:
|
||||||
if getattr(self, field_name) is None:
|
if getattr(self, field_name) is None:
|
||||||
raise ValueError(f"{field_name} is required.")
|
raise ValueError(f"{field_name} is required.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from khaosz.data.dataset import (
|
from khaosz.data.dataset import (
|
||||||
BaseDataset,
|
BaseDataset,
|
||||||
SEQDataset,
|
SEQDataset,
|
||||||
DPODataset,
|
DPODataset,
|
||||||
SFTDataset,
|
SFTDataset,
|
||||||
GRPODataset,
|
GRPODataset,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
DatasetLoader,
|
DatasetLoader,
|
||||||
DatasetFactory
|
DatasetFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
|
|
@ -15,21 +15,17 @@ from khaosz.data.sampler import ResumableDistributedSampler
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base classes
|
# Base classes
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
|
|
||||||
# Dataset implementations
|
# Dataset implementations
|
||||||
"SEQDataset",
|
"SEQDataset",
|
||||||
"SFTDataset",
|
"SFTDataset",
|
||||||
"DPODataset",
|
"DPODataset",
|
||||||
"GRPODataset",
|
"GRPODataset",
|
||||||
|
|
||||||
# Fetchers
|
# Fetchers
|
||||||
"MultiSegmentFetcher",
|
"MultiSegmentFetcher",
|
||||||
|
|
||||||
# Factory (DatasetLoader is alias for backward compatibility)
|
# Factory (DatasetLoader is alias for backward compatibility)
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
|
|
||||||
# Tokenizer and sampler
|
# Tokenizer and sampler
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
"ResumableDistributedSampler"
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -12,40 +12,42 @@ from typing import Callable, List, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class BaseSegmentFetcher:
|
||||||
"""Fetches data segments across multiple tensor segments.
|
"""Fetches data segments across multiple tensor segments.
|
||||||
|
|
||||||
Maintains cumulative lengths for efficient range queries across
|
Maintains cumulative lengths for efficient range queries across
|
||||||
multiple discontinuous segments.
|
multiple discontinuous segments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, segments: List[Tensor]):
|
def __init__(self, segments: List[Tensor]):
|
||||||
self.segments = segments
|
self.segments = segments
|
||||||
self.cum_lengths = []
|
self.cum_lengths = []
|
||||||
|
|
||||||
total = 0
|
total = 0
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
total += torch.numel(seg)
|
total += torch.numel(seg)
|
||||||
self.cum_lengths.append(total)
|
self.cum_lengths.append(total)
|
||||||
|
|
||||||
self.total_length = total
|
self.total_length = total
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.total_length
|
return self.total_length
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
"""Fetch data in the range [begin_idx, end_idx).
|
"""Fetch data in the range [begin_idx, end_idx).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
begin_idx: Starting index (inclusive)
|
begin_idx: Starting index (inclusive)
|
||||||
end_idx: Ending index (exclusive)
|
end_idx: Ending index (exclusive)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Concatenated tensor of data in the specified range
|
Concatenated tensor of data in the specified range
|
||||||
"""
|
"""
|
||||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
if not (
|
||||||
|
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||||
|
):
|
||||||
raise ValueError("begin_idx or end_idx out of bounds")
|
raise ValueError("begin_idx or end_idx out of bounds")
|
||||||
if begin_idx >= end_idx:
|
if begin_idx >= end_idx:
|
||||||
return torch.tensor([], dtype=torch.long)
|
return torch.tensor([], dtype=torch.long)
|
||||||
|
|
||||||
# Find segment boundaries for the range
|
# Find segment boundaries for the range
|
||||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||||
|
|
@ -64,43 +66,44 @@ class BaseSegmentFetcher:
|
||||||
|
|
||||||
class MultiSegmentFetcher:
|
class MultiSegmentFetcher:
|
||||||
"""Manages multiple segment fetchers for different data keys.
|
"""Manages multiple segment fetchers for different data keys.
|
||||||
|
|
||||||
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, muti_segments: Dict):
|
def __init__(self, muti_segments: Dict):
|
||||||
self.muti_keys = list(muti_segments.keys())
|
self.muti_keys = list(muti_segments.keys())
|
||||||
self.muti_fetchers = {
|
self.muti_fetchers = {
|
||||||
key: BaseSegmentFetcher(segments)
|
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
|
||||||
for key, segments in muti_segments.items()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Returns the minimum length across all fetchers."""
|
"""Returns the minimum length across all fetchers."""
|
||||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||||
return min(len_list)
|
return min(len_list)
|
||||||
|
|
||||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
def key_fetch(
|
||||||
|
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||||
|
) -> Dict:
|
||||||
"""Fetch data for specific keys.
|
"""Fetch data for specific keys.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
begin_idx: Starting index
|
begin_idx: Starting index
|
||||||
end_idx: Ending index
|
end_idx: Ending index
|
||||||
keys: Single key or list of keys to fetch
|
keys: Single key or list of keys to fetch
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of tensors if multiple keys, single tensor if one key
|
Dictionary of tensors if multiple keys, single tensor if one key
|
||||||
"""
|
"""
|
||||||
fetch_dict = {}
|
fetch_dict = {}
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
keys = [keys] if isinstance(keys, str) else keys
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
fetcher = self.muti_fetchers[key]
|
fetcher = self.muti_fetchers[key]
|
||||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||||
fetch_dict[key] = fetch_tensor
|
fetch_dict[key] = fetch_tensor
|
||||||
|
|
||||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||||
"""Fetch all keys."""
|
"""Fetch all keys."""
|
||||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||||
|
|
@ -108,10 +111,10 @@ class MultiSegmentFetcher:
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
class BaseDataset(Dataset, ABC):
|
||||||
"""Abstract base class for all dataset types.
|
"""Abstract base class for all dataset types.
|
||||||
|
|
||||||
Implements common functionality for window-based data fetching.
|
Implements common functionality for window-based data fetching.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.segments = {}
|
self.segments = {}
|
||||||
|
|
@ -122,38 +125,38 @@ class BaseDataset(Dataset, ABC):
|
||||||
|
|
||||||
def load(self, load_path: str):
|
def load(self, load_path: str):
|
||||||
"""Load dataset from HDF5 file.
|
"""Load dataset from HDF5 file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
load_path: Path to the HDF5 data file
|
load_path: Path to the HDF5 data file
|
||||||
"""
|
"""
|
||||||
self.segments = load_h5(load_path)
|
self.segments = load_h5(load_path)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||||
self.total_samples = len(self.fetcher)
|
self.total_samples = len(self.fetcher)
|
||||||
|
|
||||||
def get_index(self, index: int) -> tuple:
|
def get_index(self, index: int) -> tuple:
|
||||||
"""Calculate begin and end indices for a sample.
|
"""Calculate begin and end indices for a sample.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index: Sample index
|
index: Sample index
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (begin_idx, end_idx)
|
Tuple of (begin_idx, end_idx)
|
||||||
"""
|
"""
|
||||||
assert self.total_samples > self.window_size
|
assert self.total_samples > self.window_size
|
||||||
|
|
||||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
||||||
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
||||||
|
|
||||||
return begin_idx, end_idx
|
return begin_idx, end_idx
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
"""Get a single sample by index.
|
"""Get a single sample by index.
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
Must be implemented by subclasses.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
assert self.total_samples is not None
|
assert self.total_samples is not None
|
||||||
if self.total_samples <= self.window_size:
|
if self.total_samples <= self.window_size:
|
||||||
|
|
@ -163,48 +166,50 @@ class BaseDataset(Dataset, ABC):
|
||||||
|
|
||||||
class DatasetFactory:
|
class DatasetFactory:
|
||||||
"""Factory class for creating dataset instances.
|
"""Factory class for creating dataset instances.
|
||||||
|
|
||||||
Supports decorator-based registration for extensible dataset types.
|
Supports decorator-based registration for extensible dataset types.
|
||||||
All default dataset types (seq, sft, dpo, grpo) are registered automatically
|
All default dataset types (seq, sft, dpo, grpo) are registered automatically
|
||||||
when their classes are defined with the decorator.
|
when their classes are defined with the decorator.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
@DatasetFactory.register("custom")
|
@DatasetFactory.register("custom")
|
||||||
class CustomDataset(BaseDataset):
|
class CustomDataset(BaseDataset):
|
||||||
...
|
...
|
||||||
|
|
||||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
|
SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||||
DATASET_MAP: Dict[str, type] = {}
|
DATASET_MAP: Dict[str, type] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str):
|
def register(cls, name: str):
|
||||||
"""Decorator to register a new dataset class.
|
"""Decorator to register a new dataset class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Registration name for the dataset type
|
name: Registration name for the dataset type
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function that registers the dataset class
|
Decorator function that registers the dataset class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(dataset_cls: type) -> type:
|
def decorator(dataset_cls: type) -> type:
|
||||||
if not issubclass(dataset_cls, BaseDataset):
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
cls.DATASET_MAP[name] = dataset_cls
|
cls.DATASET_MAP[name] = dataset_cls
|
||||||
return dataset_cls
|
return dataset_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
|
def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset:
|
||||||
"""Create a dataset instance.
|
"""Create a dataset instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||||
window_size: Window size for data sampling
|
window_size: Window size for data sampling
|
||||||
stride: Stride between consecutive samples
|
stride: Stride between consecutive samples
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset instance
|
Dataset instance
|
||||||
"""
|
"""
|
||||||
|
|
@ -213,36 +218,42 @@ class DatasetFactory:
|
||||||
f"Unknown dataset type: '{train_type}'. "
|
f"Unknown dataset type: '{train_type}'. "
|
||||||
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
|
f"Supported types: {sorted(cls.SUPPORTED_TYPES)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if train_type not in cls.DATASET_MAP:
|
if train_type not in cls.DATASET_MAP:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Dataset type '{train_type}' is supported but not yet implemented."
|
f"Dataset type '{train_type}' is supported but not yet implemented."
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_cls = cls.DATASET_MAP[train_type]
|
dataset_cls = cls.DATASET_MAP[train_type]
|
||||||
return dataset_cls(window_size, stride)
|
return dataset_cls(window_size, stride)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
|
def load(
|
||||||
|
cls,
|
||||||
|
train_type: str,
|
||||||
|
load_path: str,
|
||||||
|
window_size: int,
|
||||||
|
stride: Optional[int] = None,
|
||||||
|
) -> BaseDataset:
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
train_type: Type of training dataset
|
train_type: Type of training dataset
|
||||||
load_path: Path to the data file
|
load_path: Path to the data file
|
||||||
window_size: Window size for data sampling
|
window_size: Window size for data sampling
|
||||||
stride: Stride between consecutive samples (default: same as window_size)
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded dataset instance
|
Loaded dataset instance
|
||||||
"""
|
"""
|
||||||
if stride is None:
|
if stride is None:
|
||||||
stride = window_size
|
stride = window_size
|
||||||
|
|
||||||
dataset = cls.create(train_type, window_size, stride)
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
dataset.load(load_path)
|
dataset.load(load_path)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def available_types(cls) -> list:
|
def available_types(cls) -> list:
|
||||||
"""Return list of registered dataset type names."""
|
"""Return list of registered dataset type names."""
|
||||||
|
|
@ -256,46 +267,50 @@ class DatasetFactory:
|
||||||
@DatasetFactory.register("seq")
|
@DatasetFactory.register("seq")
|
||||||
class SEQDataset(BaseDataset):
|
class SEQDataset(BaseDataset):
|
||||||
"""Dataset for sequential next-token prediction training."""
|
"""Dataset for sequential next-token prediction training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y}
|
return {"input_ids": x, "target_ids": y}
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("sft")
|
@DatasetFactory.register("sft")
|
||||||
class SFTDataset(BaseDataset):
|
class SFTDataset(BaseDataset):
|
||||||
"""Dataset for supervised fine-tuning with loss masking."""
|
"""Dataset for supervised fine-tuning with loss masking."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("dpo")
|
@DatasetFactory.register("dpo")
|
||||||
class DPODataset(BaseDataset):
|
class DPODataset(BaseDataset):
|
||||||
"""Dataset for Direct Preference Optimization training."""
|
"""Dataset for Direct Preference Optimization training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
|
@ -304,25 +319,34 @@ class DPODataset(BaseDataset):
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||||
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||||
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
|
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
|
||||||
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
return {
|
||||||
|
"chosen": chosen,
|
||||||
|
"rejected": rejected,
|
||||||
|
"chosen_mask": chosen_mask,
|
||||||
|
"rejected_mask": rejected_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("grpo")
|
@DatasetFactory.register("grpo")
|
||||||
class GRPODataset(BaseDataset):
|
class GRPODataset(BaseDataset):
|
||||||
"""Dataset for Group Relative Policy Optimization training."""
|
"""Dataset for Group Relative Policy Optimization training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
|
|
@ -330,8 +354,13 @@ class GRPODataset(BaseDataset):
|
||||||
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
||||||
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
return {
|
||||||
|
"prompts": prompts,
|
||||||
|
"responses": responses,
|
||||||
|
"masks": masks,
|
||||||
|
"rewards": rewards,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias
|
# Backward compatibility alias
|
||||||
|
|
|
||||||
|
|
@ -7,45 +7,45 @@ from typing import Optional
|
||||||
|
|
||||||
class ResumableDistributedSampler(Sampler[int]):
|
class ResumableDistributedSampler(Sampler[int]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_source: Dataset,
|
data_source: Dataset,
|
||||||
start_epoch: int=0,
|
start_epoch: int = 0,
|
||||||
start_iter: int=0,
|
start_iter: int = 0,
|
||||||
seed: int=42,
|
seed: int = 42,
|
||||||
drop_last: bool=False,
|
drop_last: bool = False,
|
||||||
shuffle: bool=True,
|
shuffle: bool = True,
|
||||||
process_group: Optional[dist.ProcessGroup]=None,
|
process_group: Optional[dist.ProcessGroup] = None,
|
||||||
):
|
):
|
||||||
self.epoch = start_epoch
|
self.epoch = start_epoch
|
||||||
self.iter = start_iter
|
self.iter = start_iter
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.num_samples = len(data_source)
|
self.num_samples = len(data_source)
|
||||||
|
|
||||||
if process_group is not None:
|
if process_group is not None:
|
||||||
# input process group
|
# input process group
|
||||||
self.rank = dist.get_rank(process_group)
|
self.rank = dist.get_rank(process_group)
|
||||||
self.num_replicas = dist.get_world_size(process_group)
|
self.num_replicas = dist.get_world_size(process_group)
|
||||||
|
|
||||||
elif dist.is_available() and dist.is_initialized():
|
elif dist.is_available() and dist.is_initialized():
|
||||||
# use default process group
|
# use default process group
|
||||||
process_group = dist.group.WORLD
|
process_group = dist.group.WORLD
|
||||||
self.rank = dist.get_rank()
|
self.rank = dist.get_rank()
|
||||||
self.num_replicas = dist.get_world_size()
|
self.num_replicas = dist.get_world_size()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# single process
|
# single process
|
||||||
self.rank = 0
|
self.rank = 0
|
||||||
self.num_replicas = 1
|
self.num_replicas = 1
|
||||||
|
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
def _get_indices(self):
|
def _get_indices(self):
|
||||||
if self.shuffle:
|
if self.shuffle:
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
|
|
@ -53,26 +53,26 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
indices = torch.randperm(self.num_samples, generator=generator).tolist()
|
indices = torch.randperm(self.num_samples, generator=generator).tolist()
|
||||||
else:
|
else:
|
||||||
indices = torch.arange(self.num_samples).tolist()
|
indices = torch.arange(self.num_samples).tolist()
|
||||||
|
|
||||||
if not self.drop_last and self.num_samples < self.total_size:
|
if not self.drop_last and self.num_samples < self.total_size:
|
||||||
padding_size = self.total_size - len(indices)
|
padding_size = self.total_size - len(indices)
|
||||||
indices += indices[:padding_size]
|
indices += indices[:padding_size]
|
||||||
|
|
||||||
local_indices = indices[self.rank:self.total_size:self.num_replicas]
|
local_indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||||
|
|
||||||
self.iter = self.iter % self.num_samples_per_replica
|
self.iter = self.iter % self.num_samples_per_replica
|
||||||
self._indices = local_indices[self.iter:]
|
self._indices = local_indices[self.iter :]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if self._indices is None:
|
if self._indices is None:
|
||||||
self._get_indices()
|
self._get_indices()
|
||||||
|
|
||||||
for i in self._indices:
|
for i in self._indices:
|
||||||
self.iter += 1
|
self.iter += 1
|
||||||
yield i
|
yield i
|
||||||
|
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples_per_replica
|
return self.num_samples_per_replica
|
||||||
|
|
|
||||||
|
|
@ -10,24 +10,26 @@ from torch import Tensor
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from khaosz.parallel.setup import get_rank
|
from khaosz.parallel.setup import get_rank
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||||
with h5py.File(full_file_path, 'w') as f:
|
with h5py.File(full_file_path, "w") as f:
|
||||||
for key, tensors in tensor_group.items():
|
for key, tensors in tensor_group.items():
|
||||||
grp = f.create_group(key)
|
grp = f.create_group(key)
|
||||||
for idx, tensor in enumerate(tensors):
|
for idx, tensor in enumerate(tensors):
|
||||||
arr = tensor.cpu().numpy()
|
arr = tensor.cpu().numpy()
|
||||||
grp.create_dataset(f'data_{idx}', data=arr)
|
grp.create_dataset(f"data_{idx}", data=arr)
|
||||||
|
|
||||||
|
|
||||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
|
|
||||||
root_path = Path(file_path)
|
root_path = Path(file_path)
|
||||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||||
|
|
||||||
for h5_file in h5_files:
|
for h5_file in h5_files:
|
||||||
with h5py.File(h5_file, 'r') as f:
|
with h5py.File(h5_file, "r") as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
grp = f[key]
|
grp = f[key]
|
||||||
dsets = []
|
dsets = []
|
||||||
|
|
@ -37,7 +39,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
if share_memory:
|
if share_memory:
|
||||||
tensor = tensor.share_memory_()
|
tensor = tensor.share_memory_()
|
||||||
dsets.append(tensor)
|
dsets.append(tensor)
|
||||||
|
|
||||||
if tensor_group.get(key) is None:
|
if tensor_group.get(key) is None:
|
||||||
tensor_group[key] = []
|
tensor_group[key] = []
|
||||||
tensor_group[key].extend(dsets)
|
tensor_group[key].extend(dsets)
|
||||||
|
|
@ -60,7 +62,7 @@ class Checkpoint:
|
||||||
self,
|
self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
@ -72,7 +74,7 @@ class Checkpoint:
|
||||||
}
|
}
|
||||||
with open(save_path / "meta.json", "w") as f:
|
with open(save_path / "meta.json", "w") as f:
|
||||||
json.dump(meta, f, indent=2)
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
|
st.save_file(self.state_dict, save_path / f"state_dict.safetensors")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -83,7 +85,7 @@ class Checkpoint:
|
||||||
|
|
||||||
rank = get_rank()
|
rank = get_rank()
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = {}
|
meta = {}
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||||
|
|
@ -100,4 +102,4 @@ class Checkpoint:
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta["epoch"],
|
epoch=meta["epoch"],
|
||||||
iteration=meta["iteration"],
|
iteration=meta["iteration"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,34 +9,46 @@ class BpeTokenizer:
|
||||||
def __init__(self, path=None):
|
def __init__(self, path=None):
|
||||||
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
|
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
|
||||||
self._special_tokens = ["<|im_start|>", "<|im_end|>"]
|
self._special_tokens = ["<|im_start|>", "<|im_end|>"]
|
||||||
|
|
||||||
model = BPE()
|
model = BPE()
|
||||||
self._tokenizer = Tokenizer(model)
|
self._tokenizer = Tokenizer(model)
|
||||||
self._tokenizer.normalizer = normalizers.Sequence([
|
self._tokenizer.normalizer = normalizers.Sequence(
|
||||||
normalizers.NFC(),
|
[normalizers.NFC(), normalizers.Strip()]
|
||||||
normalizers.Strip()
|
)
|
||||||
])
|
|
||||||
|
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||||
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
|
[
|
||||||
pre_tokenizers.UnicodeScripts(),
|
pre_tokenizers.UnicodeScripts(),
|
||||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)
|
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
self._tokenizer.decoder = decoders.ByteLevel()
|
self._tokenizer.decoder = decoders.ByteLevel()
|
||||||
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
|
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
|
||||||
|
|
||||||
if path is not None:
|
if path is not None:
|
||||||
self._tokenizer = Tokenizer.from_file(path)
|
self._tokenizer = Tokenizer.from_file(path)
|
||||||
|
|
||||||
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple:
|
def _prepare_trainer(
|
||||||
|
self,
|
||||||
|
vocab_size: int,
|
||||||
|
min_freq: int,
|
||||||
|
reserved_token_size: int,
|
||||||
|
max_token_length=18,
|
||||||
|
) -> tuple:
|
||||||
assert reserved_token_size > len(self._special_tokens)
|
assert reserved_token_size > len(self._special_tokens)
|
||||||
reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
|
reserved_tokens = [
|
||||||
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
|
f"<|reserve{i:02d}|>"
|
||||||
|
for i in range(reserved_token_size - len(self._special_tokens))
|
||||||
|
]
|
||||||
|
detail_vocab_size = vocab_size - (
|
||||||
|
len(reserved_tokens) + len(self._special_tokens)
|
||||||
|
)
|
||||||
|
|
||||||
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
||||||
min_size = len(alphabet) + len(self._control_tokens)
|
min_size = len(alphabet) + len(self._control_tokens)
|
||||||
assert detail_vocab_size > min_size
|
assert detail_vocab_size > min_size
|
||||||
|
|
||||||
trainer = BpeTrainer(
|
trainer = BpeTrainer(
|
||||||
vocab_size=detail_vocab_size,
|
vocab_size=detail_vocab_size,
|
||||||
min_frequency=min_freq,
|
min_frequency=min_freq,
|
||||||
|
|
@ -46,61 +58,74 @@ class BpeTokenizer:
|
||||||
initial_alphabet=alphabet,
|
initial_alphabet=alphabet,
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return trainer, detail_vocab_size, reserved_tokens
|
return trainer, detail_vocab_size, reserved_tokens
|
||||||
|
|
||||||
def train(self, files, vocab_size, min_freq, reserved_token_size=100):
|
def train(self, files, vocab_size, min_freq, reserved_token_size=100):
|
||||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
min_freq=min_freq,
|
min_freq=min_freq,
|
||||||
reserved_token_size=reserved_token_size
|
reserved_token_size=reserved_token_size,
|
||||||
)
|
)
|
||||||
self._tokenizer.train(files=files, trainer=trainer)
|
self._tokenizer.train(files=files, trainer=trainer)
|
||||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||||
|
|
||||||
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
|
def train_from_iterator(
|
||||||
|
self, iterator, vocab_size, min_freq, reserved_token_size=100
|
||||||
|
):
|
||||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
min_freq=min_freq,
|
min_freq=min_freq,
|
||||||
reserved_token_size=reserved_token_size
|
reserved_token_size=reserved_token_size,
|
||||||
)
|
)
|
||||||
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
|
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
|
||||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||||
|
|
||||||
def save(self, path):
|
def save(self, path):
|
||||||
self._tokenizer.save(path)
|
self._tokenizer.save(path)
|
||||||
|
|
||||||
def load(self, path):
|
def load(self, path):
|
||||||
self._tokenizer = Tokenizer.from_file(path)
|
self._tokenizer = Tokenizer.from_file(path)
|
||||||
|
|
||||||
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
|
def encode(
|
||||||
|
self,
|
||||||
|
tokens: Union[str, List[str]],
|
||||||
|
out_ids: bool = True,
|
||||||
|
add_special_tokens: bool = False,
|
||||||
|
) -> List:
|
||||||
if isinstance(tokens, str):
|
if isinstance(tokens, str):
|
||||||
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
|
encoded: Encoding = self._tokenizer.encode(
|
||||||
|
tokens, add_special_tokens=add_special_tokens
|
||||||
|
)
|
||||||
return encoded.ids if out_ids else encoded.tokens
|
return encoded.ids if out_ids else encoded.tokens
|
||||||
elif isinstance(tokens, list):
|
elif isinstance(tokens, list):
|
||||||
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
|
encoded_list: List[Encoding] = self._tokenizer.encode_batch(
|
||||||
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
|
tokens, add_special_tokens=add_special_tokens
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
|
||||||
|
]
|
||||||
|
|
||||||
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
|
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
||||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self._tokenizer.get_vocab_size()
|
return self._tokenizer.get_vocab_size()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stop_ids(self) -> List[int]:
|
def stop_ids(self) -> List[int]:
|
||||||
stop_token = self._control_tokens + self._special_tokens
|
stop_token = self._control_tokens + self._special_tokens
|
||||||
stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
|
stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
|
||||||
return stop_ids
|
return stop_ids
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bos_id(self) -> int:
|
def bos_id(self) -> int:
|
||||||
return self._tokenizer.token_to_id("<bos>")
|
return self._tokenizer.token_to_id("<bos>")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eos_id(self) -> int:
|
def eos_id(self) -> int:
|
||||||
return self._tokenizer.token_to_id("<eos>")
|
return self._tokenizer.token_to_id("<eos>")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pad_id(self) -> int:
|
def pad_id(self) -> int:
|
||||||
return self._tokenizer.token_to_id("<pad>")
|
return self._tokenizer.token_to_id("<pad>")
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from khaosz.inference.generator import (
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
GeneratorFactory
|
GeneratorFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -19,11 +19,10 @@ __all__ = [
|
||||||
"GeneratorCore",
|
"GeneratorCore",
|
||||||
"EmbeddingEncoderCore",
|
"EmbeddingEncoderCore",
|
||||||
"KVCacheManager",
|
"KVCacheManager",
|
||||||
|
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"LoopGenerator",
|
"LoopGenerator",
|
||||||
"StreamGenerator",
|
"StreamGenerator",
|
||||||
"BatchGenerator",
|
"BatchGenerator",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
"GeneratorFactory"
|
"GeneratorFactory",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||||
from khaosz.config import ModelParameter, ModelConfig
|
from khaosz.config import ModelParameter, ModelConfig
|
||||||
|
|
@ -12,58 +12,61 @@ def apply_sampling_strategies(
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
filter_value: float = -float("inf")
|
filter_value: float = -float("inf"),
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Apply sampling strategies to the logits tensor.
|
Apply sampling strategies to the logits tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (Tensor): The logits tensor.
|
logits (Tensor): The logits tensor.
|
||||||
temperature (float): The temperature parameter.
|
temperature (float): The temperature parameter.
|
||||||
top_k (int): The top-k parameter.
|
top_k (int): The top-k parameter.
|
||||||
top_p (float): The top-p parameter.
|
top_p (float): The top-p parameter.
|
||||||
filter_value (float, optional): The filter value. Defaults to -float("inf").
|
filter_value (float, optional): The filter value. Defaults to -float("inf").
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The sampled logits tensor.
|
Tensor: The sampled logits tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
if top_k > 0:
|
if top_k > 0:
|
||||||
top_k = min(top_k, logits.size(-1))
|
top_k = min(top_k, logits.size(-1))
|
||||||
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
||||||
logits[indices_to_remove] = filter_value
|
logits[indices_to_remove] = filter_value
|
||||||
|
|
||||||
if top_p < 1.0:
|
if top_p < 1.0:
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
|
||||||
sorted_indices_to_remove = cumulative_probs > top_p
|
sorted_indices_to_remove = cumulative_probs > top_p
|
||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
|
||||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||||
indices_to_remove.scatter_(
|
indices_to_remove.scatter_(
|
||||||
dim=1,
|
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||||
index=sorted_indices,
|
|
||||||
src=sorted_indices_to_remove
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logits[indices_to_remove] = filter_value
|
logits[indices_to_remove] = filter_value
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def disable_random_init():
|
def disable_random_init():
|
||||||
init_functions = [
|
init_functions = [
|
||||||
'xavier_normal_', 'xavier_uniform_',
|
"xavier_normal_",
|
||||||
'kaiming_normal_', 'kaiming_uniform_',
|
"xavier_uniform_",
|
||||||
'zeros_', 'ones_', 'constant_',
|
"kaiming_normal_",
|
||||||
'normal_', 'uniform_'
|
"kaiming_uniform_",
|
||||||
|
"zeros_",
|
||||||
|
"ones_",
|
||||||
|
"constant_",
|
||||||
|
"normal_",
|
||||||
|
"uniform_",
|
||||||
]
|
]
|
||||||
original_funcs = {}
|
original_funcs = {}
|
||||||
for name in init_functions:
|
for name in init_functions:
|
||||||
|
|
@ -82,7 +85,7 @@ class GeneratorCore:
|
||||||
self.model = parameter.model
|
self.model = parameter.model
|
||||||
self.tokenizer = parameter.tokenizer
|
self.tokenizer = parameter.tokenizer
|
||||||
self.config = parameter.config
|
self.config = parameter.config
|
||||||
|
|
||||||
def generate_iterator(
|
def generate_iterator(
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
|
|
@ -91,18 +94,18 @@ class GeneratorCore:
|
||||||
top_p: float,
|
top_p: float,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
)-> Tuple[Tensor, int]:
|
) -> Tuple[Tensor, int]:
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
||||||
logits = outputs["logits"][:, -1, :]
|
logits = outputs["logits"][:, -1, :]
|
||||||
cache_increase = input_ids.size(-1)
|
cache_increase = input_ids.size(-1)
|
||||||
|
|
||||||
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
|
||||||
probs = torch.softmax(logits, dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
next_token_id = torch.multinomial(probs, num_samples=1)
|
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
return next_token_id, cache_increase
|
return next_token_id, cache_increase
|
||||||
|
|
||||||
def generate_loop(
|
def generate_loop(
|
||||||
|
|
@ -115,14 +118,21 @@ class GeneratorCore:
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
callback: Optional[Callable[..., Any]] = None
|
callback: Optional[Callable[..., Any]] = None,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
cur_cache_pos = start_pos
|
cur_cache_pos = start_pos
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.max_len):
|
for _ in range(len(ids), self.config.max_len):
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
input_ids,
|
||||||
|
temperature,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
attn_mask,
|
||||||
|
kv_caches,
|
||||||
|
cur_cache_pos,
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = next_token_id
|
input_ids = next_token_id
|
||||||
ids.append(next_token_id.item())
|
ids.append(next_token_id.item())
|
||||||
cur_cache_pos += cache_increase
|
cur_cache_pos += cache_increase
|
||||||
|
|
@ -132,9 +142,9 @@ class GeneratorCore:
|
||||||
|
|
||||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||||
break
|
break
|
||||||
|
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def to(self, *args, **kargs) -> Self:
|
def to(self, *args, **kargs) -> Self:
|
||||||
self.model.to(*args, **kargs)
|
self.model.to(*args, **kargs)
|
||||||
return self
|
return self
|
||||||
|
|
@ -145,32 +155,35 @@ class EmbeddingEncoderCore:
|
||||||
self.model = parameter.model
|
self.model = parameter.model
|
||||||
self.tokenizer = parameter.tokenizer
|
self.tokenizer = parameter.tokenizer
|
||||||
self.config = parameter.config
|
self.config = parameter.config
|
||||||
|
|
||||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||||
with_batch = isinstance(sentence, list)
|
with_batch = isinstance(sentence, list)
|
||||||
ids = self.tokenizer.encode(sentence)
|
ids = self.tokenizer.encode(sentence)
|
||||||
batch_ids = ids if with_batch else [ids]
|
batch_ids = ids if with_batch else [ids]
|
||||||
max_model_len = self.config.max_len
|
max_model_len = self.config.max_len
|
||||||
|
|
||||||
all_fragments = []
|
all_fragments = []
|
||||||
fragment_origin_idx = []
|
fragment_origin_idx = []
|
||||||
|
|
||||||
for i, seq in enumerate(batch_ids):
|
for i, seq in enumerate(batch_ids):
|
||||||
if len(seq) > max_model_len:
|
if len(seq) > max_model_len:
|
||||||
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
|
fragments = [
|
||||||
|
seq[j : j + max_model_len]
|
||||||
|
for j in range(0, len(seq), max_model_len)
|
||||||
|
]
|
||||||
all_fragments.extend(fragments)
|
all_fragments.extend(fragments)
|
||||||
fragment_origin_idx.extend([i] * len(fragments))
|
fragment_origin_idx.extend([i] * len(fragments))
|
||||||
else:
|
else:
|
||||||
all_fragments.append(seq)
|
all_fragments.append(seq)
|
||||||
fragment_origin_idx.append(i)
|
fragment_origin_idx.append(i)
|
||||||
|
|
||||||
#if empty fragments
|
# if empty fragments
|
||||||
if not all_fragments or not ids:
|
if not all_fragments or not ids:
|
||||||
return [] if with_batch else torch.tensor([])
|
return [] if with_batch else torch.tensor([])
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
|
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
|
||||||
|
|
||||||
padded_ids = []
|
padded_ids = []
|
||||||
masks = []
|
masks = []
|
||||||
for seq in all_fragments:
|
for seq in all_fragments:
|
||||||
|
|
@ -179,24 +192,30 @@ class EmbeddingEncoderCore:
|
||||||
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
|
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
|
||||||
padded_ids.append(padded_seq)
|
padded_ids.append(padded_seq)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
|
||||||
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
|
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
|
||||||
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
|
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
|
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
|
||||||
# [num_fragments, seq_len, hidden_size]
|
# [num_fragments, seq_len, hidden_size]
|
||||||
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
||||||
|
|
||||||
sentence_embs: List[Tensor] = []
|
sentence_embs: List[Tensor] = []
|
||||||
for i in range(len(batch_ids)):
|
for i in range(len(batch_ids)):
|
||||||
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
indices = [
|
||||||
|
idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
|
||||||
|
]
|
||||||
if indices:
|
if indices:
|
||||||
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
sum_frags = torch.sum(
|
||||||
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
fragment_embs[indices, :, :], dim=1
|
||||||
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
) # [frags, hidden_size]
|
||||||
|
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(
|
||||||
|
1
|
||||||
|
) # [frags, 1]
|
||||||
|
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
||||||
sentence_embs.append(emb.flatten())
|
sentence_embs.append(emb.flatten())
|
||||||
|
|
||||||
if with_batch:
|
if with_batch:
|
||||||
return [emb.flatten() for emb in sentence_embs]
|
return [emb.flatten() for emb in sentence_embs]
|
||||||
else:
|
else:
|
||||||
|
|
@ -209,11 +228,11 @@ class EmbeddingEncoderCore:
|
||||||
|
|
||||||
class KVCacheManager:
|
class KVCacheManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
device: torch.device = "cuda",
|
device: torch.device = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
):
|
):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
@ -221,25 +240,41 @@ class KVCacheManager:
|
||||||
self.num_layers = config.n_layers
|
self.num_layers = config.n_layers
|
||||||
self.max_len = config.max_len
|
self.max_len = config.max_len
|
||||||
self.num_heads = config.n_kv_heads
|
self.num_heads = config.n_kv_heads
|
||||||
self.head_dim = config.dim //config.n_heads
|
self.head_dim = config.dim // config.n_heads
|
||||||
|
|
||||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||||
self._seq_mask: Tensor = None
|
self._seq_mask: Tensor = None
|
||||||
self._initialize()
|
self._initialize()
|
||||||
|
|
||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
k_cache = torch.empty(
|
k_cache = torch.empty(
|
||||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
(
|
||||||
device=self.device, dtype=self.dtype
|
self.batch_size,
|
||||||
|
self.max_len,
|
||||||
|
self.num_layers,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
),
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
v_cache = torch.empty(
|
v_cache = torch.empty(
|
||||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
(
|
||||||
device=self.device, dtype=self.dtype
|
self.batch_size,
|
||||||
|
self.max_len,
|
||||||
|
self.num_layers,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
),
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self._kv_cache = (k_cache, v_cache)
|
self._kv_cache = (k_cache, v_cache)
|
||||||
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)
|
self._seq_mask = torch.ones(
|
||||||
|
(self.batch_size, self.max_len), device=self.device, dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
def update(self, active_mask: Tensor):
|
def update(self, active_mask: Tensor):
|
||||||
k_cache, v_cache = self._kv_cache
|
k_cache, v_cache = self._kv_cache
|
||||||
self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
|
self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
|
||||||
self._seq_mask = self._seq_mask[active_mask]
|
self._seq_mask = self._seq_mask[active_mask]
|
||||||
|
|
@ -250,14 +285,14 @@ class KVCacheManager:
|
||||||
self._seq_mask = None
|
self._seq_mask = None
|
||||||
else:
|
else:
|
||||||
self._initialize()
|
self._initialize()
|
||||||
|
|
||||||
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
|
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
|
||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = input_ids.shape
|
||||||
bool_mask = (input_ids != pad_id)
|
bool_mask = input_ids != pad_id
|
||||||
self._seq_mask[: batch_size, : seq_len] = bool_mask
|
self._seq_mask[:batch_size, :seq_len] = bool_mask
|
||||||
|
|
||||||
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
|
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
|
||||||
return self._kv_cache
|
return self._kv_cache
|
||||||
|
|
||||||
def get_seq_mask(self) -> Tensor:
|
def get_seq_mask(self) -> Tensor:
|
||||||
return self._seq_mask
|
return self._seq_mask
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import List, Tuple, Union, Optional, Generator
|
from typing import List, Tuple, Union, Optional, Generator
|
||||||
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
|
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
|
||||||
from khaosz.config.param_config import ModelParameter
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
|
@ -8,10 +8,11 @@ from khaosz.config.param_config import ModelParameter
|
||||||
|
|
||||||
HistoryType = List[Tuple[str, str]]
|
HistoryType = List[Tuple[str, str]]
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
query: str,
|
query: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
history: Optional[HistoryType] = None
|
history: Optional[HistoryType] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build prompt in ChatML format for query and history.
|
Build prompt in ChatML format for query and history.
|
||||||
|
|
@ -42,17 +43,17 @@ def build_prompt(
|
||||||
|
|
||||||
|
|
||||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||||
"""
|
"""
|
||||||
Pad a list of sequences to a fixed length.
|
Pad a list of sequences to a fixed length.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ids_list (List[List[int]]): A list of sequences.
|
ids_list (List[List[int]]): A list of sequences.
|
||||||
max_ids_len (int): The maximum length of sequences.
|
max_ids_len (int): The maximum length of sequences.
|
||||||
pad_id (int): The id to pad sequences.
|
pad_id (int): The id to pad sequences.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[int]]: A list of padded sequences.
|
List[List[int]]: A list of padded sequences.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
max_ids_len = max(len(ids) for ids in ids_list)
|
max_ids_len = max(len(ids) for ids in ids_list)
|
||||||
new_ids_list = []
|
new_ids_list = []
|
||||||
|
|
@ -60,7 +61,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
|
||||||
pad_len = max_ids_len - len(ids)
|
pad_len = max_ids_len - len(ids)
|
||||||
padded_seq = [pad_id] * pad_len + ids
|
padded_seq = [pad_id] * pad_len + ids
|
||||||
new_ids_list.append(padded_seq)
|
new_ids_list.append(padded_seq)
|
||||||
|
|
||||||
return new_ids_list, max_ids_len
|
return new_ids_list, max_ids_len
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -68,7 +69,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""
|
"""
|
||||||
Request parameters for text generation.
|
Request parameters for text generation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
top_k: Top-k sampling parameter.
|
top_k: Top-k sampling parameter.
|
||||||
top_p: Top-p (nucleus) sampling parameter.
|
top_p: Top-p (nucleus) sampling parameter.
|
||||||
|
|
@ -79,6 +80,7 @@ class GenerationRequest:
|
||||||
system_prompt: System prompt for the conversation.
|
system_prompt: System prompt for the conversation.
|
||||||
stream: Whether to use streaming generation.
|
stream: Whether to use streaming generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
top_k: int
|
top_k: int
|
||||||
top_p: float
|
top_p: float
|
||||||
temperature: float
|
temperature: float
|
||||||
|
|
@ -101,63 +103,66 @@ class GenerationRequest:
|
||||||
class LoopGenerator(GeneratorCore):
|
class LoopGenerator(GeneratorCore):
|
||||||
def __init__(self, parameter: ModelParameter):
|
def __init__(self, parameter: ModelParameter):
|
||||||
super().__init__(parameter)
|
super().__init__(parameter)
|
||||||
|
|
||||||
def generate(self, request: GenerationRequest) -> str:
|
def generate(self, request: GenerationRequest) -> str:
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||||
|
|
||||||
prompt = build_prompt(request.query, request.history)
|
prompt = build_prompt(request.query, request.history)
|
||||||
ids = self.tokenizer.encode(prompt)
|
ids = self.tokenizer.encode(prompt)
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
|
|
||||||
start_cache_pos = len(ids)
|
start_cache_pos = len(ids)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
|
|
||||||
ids = self.generate_loop(
|
ids = self.generate_loop(
|
||||||
input_ids,
|
input_ids,
|
||||||
ids,
|
ids,
|
||||||
request.temperature,
|
request.temperature,
|
||||||
request.top_k,
|
request.top_k,
|
||||||
request.top_p,
|
request.top_p,
|
||||||
kv_caches=kv_caches,
|
kv_caches=kv_caches,
|
||||||
)
|
)
|
||||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class StreamGenerator(GeneratorCore):
|
class StreamGenerator(GeneratorCore):
|
||||||
def __init__(self, parameter: ModelParameter):
|
def __init__(self, parameter: ModelParameter):
|
||||||
super().__init__(parameter)
|
super().__init__(parameter)
|
||||||
|
|
||||||
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||||
|
|
||||||
prompt = build_prompt(request.query, request.history)
|
prompt = build_prompt(request.query, request.history)
|
||||||
ids = self.tokenizer.encode(prompt)
|
ids = self.tokenizer.encode(prompt)
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
|
|
||||||
start_cache_pos = len(ids)
|
start_cache_pos = len(ids)
|
||||||
cur_cache_pos = 0
|
cur_cache_pos = 0
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.max_len):
|
for _ in range(len(ids), self.config.max_len):
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_ids, request.temperature, request.top_k, request.top_p,
|
input_ids,
|
||||||
kv_caches=kv_caches,
|
request.temperature,
|
||||||
start_pos=cur_cache_pos
|
request.top_k,
|
||||||
|
request.top_p,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
start_pos=cur_cache_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_ids = next_token_id
|
input_ids = next_token_id
|
||||||
ids.append(next_token_id.item())
|
ids.append(next_token_id.item())
|
||||||
cur_cache_pos += cache_increase
|
cur_cache_pos += cache_increase
|
||||||
|
|
||||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||||
yield response + "\n"
|
yield response + "\n"
|
||||||
break
|
break
|
||||||
|
|
@ -166,131 +171,140 @@ class StreamGenerator(GeneratorCore):
|
||||||
class BatchGenerator(GeneratorCore):
|
class BatchGenerator(GeneratorCore):
|
||||||
def __init__(self, parameter: ModelParameter):
|
def __init__(self, parameter: ModelParameter):
|
||||||
super().__init__(parameter)
|
super().__init__(parameter)
|
||||||
|
|
||||||
def generate(self, request: GenerationRequest) -> List[str]:
|
def generate(self, request: GenerationRequest) -> List[str]:
|
||||||
batch_size = len(request.query)
|
batch_size = len(request.query)
|
||||||
if request.history is None:
|
if request.history is None:
|
||||||
request.history = [[] for _ in range(batch_size)]
|
request.history = [[] for _ in range(batch_size)]
|
||||||
|
|
||||||
prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
|
prompts = [
|
||||||
|
build_prompt(query, history)
|
||||||
|
for query, history in zip(request.query, request.history)
|
||||||
|
]
|
||||||
|
|
||||||
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
||||||
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(self.config, batch_size, device=device)
|
cache_manager = KVCacheManager(self.config, batch_size, device=device)
|
||||||
|
|
||||||
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
|
||||||
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
|
||||||
activate_task_mask = [True] * batch_size
|
activate_task_mask = [True] * batch_size
|
||||||
|
|
||||||
start_cache_pos = max_ids_len
|
start_cache_pos = max_ids_len
|
||||||
cur_cache_pos = 0
|
cur_cache_pos = 0
|
||||||
|
|
||||||
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
|
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
attn_mask =cache_manager.get_seq_mask()
|
attn_mask = cache_manager.get_seq_mask()
|
||||||
|
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_tensor, request.temperature, request.top_k, request.top_p,
|
input_tensor,
|
||||||
attn_mask=attn_mask,
|
request.temperature,
|
||||||
kv_caches=kv_caches,
|
request.top_k,
|
||||||
start_pos=cur_cache_pos
|
request.top_p,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
start_pos=cur_cache_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
cur_cache_pos += cache_increase
|
cur_cache_pos += cache_increase
|
||||||
active_mask = []
|
active_mask = []
|
||||||
c_ids = 0
|
c_ids = 0
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
if activate_task_mask[i]:
|
if activate_task_mask[i]:
|
||||||
token = next_token_id[c_ids, :].item()
|
token = next_token_id[c_ids, :].item()
|
||||||
ids_list[i].append(token)
|
ids_list[i].append(token)
|
||||||
c_ids += 1
|
c_ids += 1
|
||||||
|
|
||||||
is_active = not token in self.tokenizer.stop_ids
|
is_active = not token in self.tokenizer.stop_ids
|
||||||
activate_task_mask[i] = is_active
|
activate_task_mask[i] = is_active
|
||||||
active_mask.append(is_active)
|
active_mask.append(is_active)
|
||||||
|
|
||||||
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
|
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
|
||||||
cache_manager.update(active_mask)
|
cache_manager.update(active_mask)
|
||||||
input_tensor = next_token_id[active_mask, :]
|
input_tensor = next_token_id[active_mask, :]
|
||||||
|
|
||||||
max_ids_len += 1
|
max_ids_len += 1
|
||||||
|
|
||||||
responses = [str()] * batch_size
|
responses = [str()] * batch_size
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
|
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
|
||||||
request.history[i].append((request.query[i], responses[i]))
|
request.history[i].append((request.query[i], responses[i]))
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingEncoder(EmbeddingEncoderCore):
|
class EmbeddingEncoder(EmbeddingEncoderCore):
|
||||||
def __init__(self, parameter: ModelParameter):
|
def __init__(self, parameter: ModelParameter):
|
||||||
super().__init__(parameter)
|
super().__init__(parameter)
|
||||||
|
|
||||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||||
return super().encode(sentence)
|
return super().encode(sentence)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorFactory:
|
class GeneratorFactory:
|
||||||
"""Factory class for creating generator instances.
|
"""Factory class for creating generator instances.
|
||||||
|
|
||||||
Provides smart generator selection based on request characteristics:
|
Provides smart generator selection based on request characteristics:
|
||||||
- Streaming: Use StreamGenerator for streaming output
|
- Streaming: Use StreamGenerator for streaming output
|
||||||
- Batch: Use BatchGenerator when query is a list
|
- Batch: Use BatchGenerator when query is a list
|
||||||
- Single: Use LoopGenerator for single query non-streaming
|
- Single: Use LoopGenerator for single query non-streaming
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
generator = GeneratorFactory.create_generator(parameter, request)
|
generator = GeneratorFactory.create_generator(parameter, request)
|
||||||
result = generator.generate(request)
|
result = generator.generate(request)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
def create_generator(
|
||||||
|
parameter: ModelParameter, request: GenerationRequest
|
||||||
|
) -> GeneratorCore:
|
||||||
"""Create a generator based on request characteristics.
|
"""Create a generator based on request characteristics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parameter: Model parameters containing model, tokenizer, config
|
parameter: Model parameters containing model, tokenizer, config
|
||||||
request: Generation request with query, options, etc.
|
request: Generation request with query, options, etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Appropriate GeneratorCore subclass instance
|
Appropriate GeneratorCore subclass instance
|
||||||
"""
|
"""
|
||||||
# Streaming generation: check stream field first
|
# Streaming generation: check stream field first
|
||||||
if request.stream:
|
if request.stream:
|
||||||
return StreamGenerator(parameter)
|
return StreamGenerator(parameter)
|
||||||
|
|
||||||
# Batch generation: query is a list of strings
|
# Batch generation: query is a list of strings
|
||||||
if isinstance(request.query, list):
|
if isinstance(request.query, list):
|
||||||
return BatchGenerator(parameter)
|
return BatchGenerator(parameter)
|
||||||
|
|
||||||
# Default: single query non-streaming
|
# Default: single query non-streaming
|
||||||
return LoopGenerator(parameter)
|
return LoopGenerator(parameter)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
|
def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore:
|
||||||
"""Create an embedding encoder instance.
|
"""Create an embedding encoder instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parameter: Model parameters
|
parameter: Model parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EmbeddingEncoderCore instance
|
EmbeddingEncoderCore instance
|
||||||
"""
|
"""
|
||||||
return EmbeddingEncoder(parameter)
|
return EmbeddingEncoder(parameter)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
def create(
|
||||||
|
cls, parameter: ModelParameter, request: GenerationRequest
|
||||||
|
) -> GeneratorCore:
|
||||||
"""Convenience method that delegates to create_generator.
|
"""Convenience method that delegates to create_generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parameter: Model parameters
|
parameter: Model parameters
|
||||||
request: Generation request
|
request: Generation request
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generator instance
|
Generator instance
|
||||||
"""
|
"""
|
||||||
return cls.create_generator(parameter, request)
|
return cls.create_generator(parameter, request)
|
||||||
|
|
||||||
|
|
@ -1,17 +1,10 @@
|
||||||
from khaosz.model.module import (
|
from khaosz.model.module import (
|
||||||
Linear,
|
Linear,
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
MLP,
|
MLP,
|
||||||
GQA,
|
GQA,
|
||||||
DecoderBlock,
|
DecoderBlock,
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
|
||||||
"Linear",
|
|
||||||
"RMSNorm",
|
|
||||||
"MLP",
|
|
||||||
"GQA",
|
|
||||||
"DecoderBlock",
|
|
||||||
"Transformer"
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Repeat k times along the dimension for attention heads.
|
Repeat k times along the dimension for attention heads.
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): The input tensor.
|
x (Tensor): The input tensor.
|
||||||
|
|
@ -15,7 +15,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The repeated tensor.
|
Tensor: The repeated tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
bs, slen, n_heads, head_dim = x.shape
|
bs, slen, n_heads, head_dim = x.shape
|
||||||
if n_rep == 1:
|
if n_rep == 1:
|
||||||
return x
|
return x
|
||||||
|
|
@ -25,12 +25,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_emb(
|
def get_rotary_emb(
|
||||||
dim: int,
|
dim: int,
|
||||||
max_len: int,
|
max_len: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Get the rotary embedding for the given dimension and maximum length.
|
Get the rotary embedding for the given dimension and maximum length.
|
||||||
Args:
|
Args:
|
||||||
dim (int): The dimension of the input.
|
dim (int): The dimension of the input.
|
||||||
|
|
@ -46,6 +47,7 @@ def get_rotary_emb(
|
||||||
|
|
||||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Apply rotary embedding to the input tensor using cos/sin form.
|
Apply rotary embedding to the input tensor using cos/sin form.
|
||||||
|
|
@ -55,49 +57,49 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The output tensor (rotated, same shape as input).
|
Tensor: The output tensor (rotated, same shape as input).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
cos, sin = rotary_emb
|
cos, sin = rotary_emb
|
||||||
|
|
||||||
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||||
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||||
|
|
||||||
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
|
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
|
||||||
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
|
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
|
||||||
|
|
||||||
x_real_rot = x_real * cos - x_imag * sin
|
x_real_rot = x_real * cos - x_imag * sin
|
||||||
x_imag_rot = x_real * sin + x_imag * cos
|
x_imag_rot = x_real * sin + x_imag * cos
|
||||||
|
|
||||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
||||||
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
||||||
|
|
||||||
return x_out.to(dtype)
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim: int, max_len: int, base: int=10000):
|
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self.base = base
|
self.base = base
|
||||||
self.max_len_cached = None
|
self.max_len_cached = None
|
||||||
self._set_rotary_buffer(self.max_len)
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int):
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
|
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
|
||||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||||
self.max_len_cached = max_len
|
self.max_len_cached = max_len
|
||||||
|
|
||||||
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]:
|
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
||||||
seq_len = x.size(1)
|
seq_len = x.size(1)
|
||||||
|
|
||||||
if self.max_len_cached < seq_len + start_pos:
|
if self.max_len_cached < seq_len + start_pos:
|
||||||
self._set_rotary_buffer(seq_len + start_pos)
|
self._set_rotary_buffer(seq_len + start_pos)
|
||||||
|
|
||||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||||
|
|
||||||
return (cos, sin)
|
return (cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -115,43 +117,42 @@ class RMSNorm(nn.Module):
|
||||||
def __init__(self, dim, norm_eps):
|
def __init__(self, dim, norm_eps):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
self.normalized_shape = (dim, )
|
self.normalized_shape = (dim,)
|
||||||
self.norm_eps = norm_eps
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
||||||
return rms.to(x.dtype)
|
return rms.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, dim: int, dim_feed_forward: int):
|
def __init__(self, dim: int, dim_feed_forward: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = Linear(dim, dim_feed_forward)
|
self.up = Linear(dim, dim_feed_forward)
|
||||||
self.gate = Linear(dim, dim_feed_forward)
|
self.gate = Linear(dim, dim_feed_forward)
|
||||||
self.down = Linear(dim_feed_forward, dim)
|
self.down = Linear(dim_feed_forward, dim)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
gated = self.up(x) * F.silu(self.gate(x))
|
gated = self.up(x) * F.silu(self.gate(x))
|
||||||
out = self.down(gated)
|
out = self.down(gated)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GQA(nn.Module):
|
class GQA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
n_heads: int,
|
n_heads: int,
|
||||||
n_kv_heads: int,
|
n_kv_heads: int,
|
||||||
use_qk_norm: bool,
|
use_qk_norm: bool,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int
|
layer_id: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert dim % n_heads == 0
|
assert dim % n_heads == 0
|
||||||
assert n_heads % n_kv_heads == 0
|
assert n_heads % n_kv_heads == 0
|
||||||
|
|
||||||
self.head_dim = dim // n_heads
|
self.head_dim = dim // n_heads
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@ -165,11 +166,11 @@ class GQA(nn.Module):
|
||||||
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
|
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||||
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
|
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||||
self.o_proj = Linear(dim, dim)
|
self.o_proj = Linear(dim, dim)
|
||||||
|
|
||||||
if self.use_qk_norm:
|
if self.use_qk_norm:
|
||||||
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
|
||||||
if self.use_gated_attention:
|
if self.use_gated_attention:
|
||||||
self.gate = Linear(dim, dim)
|
self.gate = Linear(dim, dim)
|
||||||
|
|
||||||
|
|
@ -177,14 +178,14 @@ class GQA(nn.Module):
|
||||||
batch_size, seq_len, _ = x.shape
|
batch_size, seq_len, _ = x.shape
|
||||||
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = mask is None
|
||||||
|
|
@ -194,31 +195,36 @@ class GQA(nn.Module):
|
||||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||||
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
||||||
|
|
||||||
if self.use_qk_norm:
|
if self.use_qk_norm:
|
||||||
q, k = self.q_norm(q), self.k_norm(k)
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
|
|
||||||
# copy to cache
|
# copy to cache
|
||||||
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
|
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||||
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
|
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||||
|
|
||||||
# get cache
|
# get cache
|
||||||
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
||||||
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
|
sdqa_out = (
|
||||||
|
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.contiguous()
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
|
||||||
if self.use_gated_attention:
|
if self.use_gated_attention:
|
||||||
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||||
|
|
||||||
out = self.o_proj(sdqa_out)
|
out = self.o_proj(sdqa_out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
@ -227,15 +233,15 @@ class GQA(nn.Module):
|
||||||
class MLA(nn.Module):
|
class MLA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
n_heads: int,
|
n_heads: int,
|
||||||
n_kv_heads: int,
|
n_kv_heads: int,
|
||||||
kv_lora_rank: int,
|
kv_lora_rank: int,
|
||||||
qk_nope_head_dim: int,
|
qk_nope_head_dim: int,
|
||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int
|
layer_id: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@ -252,45 +258,46 @@ class MLA(nn.Module):
|
||||||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
|
||||||
|
|
||||||
# KV (k_nope, k_rope, v)
|
# KV (k_nope, k_rope, v)
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = Linear(dim, dim, bias=False)
|
self.o_proj = Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
if use_gated_attention:
|
if use_gated_attention:
|
||||||
self.gate = Linear(dim, dim, bias=False)
|
self.gate = Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = mask is None
|
||||||
|
|
||||||
q = self.q_proj(x)
|
q = self.q_proj(x)
|
||||||
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||||
|
|
||||||
kv_compressed = self.kv_a_proj(x)
|
kv_compressed = self.kv_a_proj(x)
|
||||||
kv_compressed = self.kv_norm(kv_compressed)
|
kv_compressed = self.kv_norm(kv_compressed)
|
||||||
|
|
||||||
kv = self.kv_b_proj(kv_compressed)
|
kv = self.kv_b_proj(kv_compressed)
|
||||||
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
|
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
|
||||||
|
|
||||||
k_nope, k_rope, v = torch.split(
|
k_nope, k_rope, v = torch.split(
|
||||||
kv,
|
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
|
||||||
[self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim],
|
)
|
||||||
dim=-1
|
|
||||||
|
q_nope, q_rope = (
|
||||||
|
q[..., : self.qk_nope_head_dim],
|
||||||
|
q[..., self.qk_rope_head_dim :],
|
||||||
)
|
)
|
||||||
|
|
||||||
q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_rope_head_dim:]
|
|
||||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||||
|
|
||||||
|
|
@ -299,41 +306,48 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
|
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||||
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
|
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||||
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
|
|
||||||
q = q.permute(0, 2, 1, 3)
|
q = q.permute(0, 2, 1, 3)
|
||||||
k = k.permute(0, 2, 1, 3)
|
k = k.permute(0, 2, 1, 3)
|
||||||
v = v.permute(0, 2, 1, 3)
|
v = v.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||||
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||||
|
|
||||||
if self.use_gated_attention:
|
if self.use_gated_attention:
|
||||||
attn_out = attn_out * F.sigmoid(self.gate(x))
|
attn_out = attn_out * F.sigmoid(self.gate(x))
|
||||||
|
|
||||||
out = self.o_proj(attn_out)
|
out = self.o_proj(attn_out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
class DecoderBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
n_heads: int,
|
n_heads: int,
|
||||||
dim_ffn: int,
|
dim_ffn: int,
|
||||||
n_kv_heads: int,
|
n_kv_heads: int,
|
||||||
norm_eps: int,
|
norm_eps: int,
|
||||||
use_qk_norm: bool,
|
use_qk_norm: bool,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int
|
layer_id: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attention = GQA(dim, n_heads, n_kv_heads,
|
self.attention = GQA(
|
||||||
use_qk_norm, norm_eps, use_gated_attention, layer_id)
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
use_qk_norm,
|
||||||
|
norm_eps,
|
||||||
|
use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
self.input_norm = RMSNorm(dim, norm_eps)
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
self.mlp = MLP(dim, dim_ffn)
|
self.mlp = MLP(dim, dim_ffn)
|
||||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
|
|
@ -341,24 +355,20 @@ class DecoderBlock(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# attention
|
# attention
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.input_norm(x),
|
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
|
||||||
rotary_emb,
|
|
||||||
attention_mask,
|
|
||||||
kv_cache,
|
|
||||||
start_pos
|
|
||||||
)
|
)
|
||||||
x = attn_output + x
|
x = attn_output + x
|
||||||
|
|
||||||
# feed forward
|
# feed forward
|
||||||
x = self.mlp(self.post_attention_norm(x)) + x
|
x = self.mlp(self.post_attention_norm(x)) + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -366,6 +376,6 @@ class Embedding(nn.Module):
|
||||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.embedding(x, self.weight)
|
return F.embedding(x, self.weight)
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,21 @@ import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Mapping, Optional, Tuple
|
from typing import Any, Mapping, Optional, Tuple
|
||||||
from khaosz.config.model_config import ModelConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
|
from khaosz.model.module import (
|
||||||
|
Embedding,
|
||||||
|
DecoderBlock,
|
||||||
|
Linear,
|
||||||
|
RMSNorm,
|
||||||
|
RotaryEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
seq_mask: Tensor,
|
seq_mask: Tensor,
|
||||||
input_tensor: Tensor,
|
input_tensor: Tensor,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Create attention mask for GQA
|
Create attention mask for GQA
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -26,32 +32,36 @@ def process_attention_mask(
|
||||||
device = input_tensor.device
|
device = input_tensor.device
|
||||||
dtype = input_tensor.dtype
|
dtype = input_tensor.dtype
|
||||||
seq_len = input_tensor.size(1)
|
seq_len = input_tensor.size(1)
|
||||||
|
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
if start_pos != 0:
|
if start_pos != 0:
|
||||||
# for single prompt chat
|
# for single prompt chat
|
||||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if seq_mask.dim() > 2:
|
if seq_mask.dim() > 2:
|
||||||
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||||
# if ndim > 2, it's 4D tensor
|
# if ndim > 2, it's 4D tensor
|
||||||
return seq_mask
|
return seq_mask
|
||||||
|
|
||||||
batch_size = seq_mask.size(0)
|
batch_size = seq_mask.size(0)
|
||||||
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||||
# (bsz, start_pos + seq_len)
|
# (bsz, start_pos + seq_len)
|
||||||
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
|
expanded_mask = seq_mask.unsqueeze(1).expand(
|
||||||
|
batch_size, seq_len, start_pos + seq_len
|
||||||
|
)
|
||||||
# (bsz, seq_len, start_pos + seq_len)
|
# (bsz, seq_len, start_pos + seq_len)
|
||||||
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||||
|
|
||||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||||
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
attention_mask = attention_mask.masked_fill_(
|
||||||
|
~expanded_mask, -torch.finfo(dtype).max / 2
|
||||||
|
).unsqueeze(1)
|
||||||
# (bsz, 1, seq_len, seq_len + start_pos)
|
# (bsz, 1, seq_len, seq_len + start_pos)
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -59,26 +69,38 @@ class Transformer(nn.Module):
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
|
self.rotary_embeding = RotaryEmbedding(
|
||||||
|
config.dim // config.n_heads, config.max_len
|
||||||
|
)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList(
|
||||||
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
|
[
|
||||||
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
|
DecoderBlock(
|
||||||
for layer_id in range(config.n_layers)
|
config.dim,
|
||||||
])
|
config.n_heads,
|
||||||
|
config.dim_ffn,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.norm_eps,
|
||||||
|
config.use_qk_norm,
|
||||||
|
config.use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight == True:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
self._init_parameters()
|
self._init_parameters()
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
lm_head_key = 'lm_head.weight'
|
lm_head_key = "lm_head.weight"
|
||||||
embed_key = 'embed_tokens.weight'
|
embed_key = "embed_tokens.weight"
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight == True:
|
||||||
# same tensor
|
# same tensor
|
||||||
|
|
@ -87,48 +109,44 @@ class Transformer(nn.Module):
|
||||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||||
# use clone to avoid sharing the same tensor
|
# use clone to avoid sharing the same tensor
|
||||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||||
|
|
||||||
return super().load_state_dict(state_dict, strict, assign)
|
return super().load_state_dict(state_dict, strict, assign)
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
state_dict = super().state_dict(
|
||||||
|
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||||
|
)
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight == True:
|
||||||
lm_head_key = prefix + 'lm_head.weight'
|
lm_head_key = prefix + "lm_head.weight"
|
||||||
if lm_head_key in state_dict:
|
if lm_head_key in state_dict:
|
||||||
del state_dict[lm_head_key]
|
del state_dict[lm_head_key]
|
||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def _init_parameters(self):
|
def _init_parameters(self):
|
||||||
for param in self.parameters():
|
for param in self.parameters():
|
||||||
if param.dim() > 1:
|
if param.dim() > 1:
|
||||||
nn.init.normal_(param, mean=0.0, std=0.006)
|
nn.init.normal_(param, mean=0.0, std=0.006)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
input_mask: Optional[Tensor]=None,
|
input_mask: Optional[Tensor] = None,
|
||||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None,
|
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
rotary_emb = self.rotary_embeding(x, start_pos)
|
rotary_emb = self.rotary_embeding(x, start_pos)
|
||||||
|
|
||||||
attn_mask = process_attention_mask(
|
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||||
input_mask, x, start_pos, is_causal=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
||||||
|
|
||||||
hidden_states = self.norm(x)
|
hidden_states = self.norm(x)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
return {
|
return {"logits": logits, "hidden_states": hidden_states}
|
||||||
"logits": logits,
|
|
||||||
"hidden_states": hidden_states
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,21 @@
|
||||||
from khaosz.parallel.setup import (
|
from khaosz.parallel.setup import (
|
||||||
get_world_size,
|
get_world_size,
|
||||||
get_rank,
|
get_rank,
|
||||||
get_current_device,
|
get_current_device,
|
||||||
|
|
||||||
only_on_rank,
|
only_on_rank,
|
||||||
setup_parallel,
|
setup_parallel,
|
||||||
spawn_parallel_fn
|
spawn_parallel_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.parallel.module import (
|
from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
|
||||||
RowParallelLinear,
|
|
||||||
ColumnParallelLinear
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_world_size",
|
"get_world_size",
|
||||||
"get_rank",
|
"get_rank",
|
||||||
"get_current_device",
|
"get_current_device",
|
||||||
|
|
||||||
"only_on_rank",
|
"only_on_rank",
|
||||||
"setup_parallel",
|
"setup_parallel",
|
||||||
"spawn_parallel_fn",
|
"spawn_parallel_fn",
|
||||||
|
|
||||||
"RowParallelLinear",
|
"RowParallelLinear",
|
||||||
"ColumnParallelLinear"
|
"ColumnParallelLinear",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -17,91 +17,99 @@ class ParallelModel(nn.Module):
|
||||||
|
|
||||||
class RowParallelLinear(ParallelModel):
|
class RowParallelLinear(ParallelModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
process_group: dist.ProcessGroup,
|
process_group: dist.ProcessGroup,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
reduce_results: bool = True
|
reduce_results: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(process_group)
|
super().__init__(process_group)
|
||||||
|
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.in_features_per_rank = in_features // self.world_size
|
self.in_features_per_rank = in_features // self.world_size
|
||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
|
|
||||||
if in_features % self.world_size != 0:
|
if in_features % self.world_size != 0:
|
||||||
raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}")
|
raise ValueError(
|
||||||
|
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
||||||
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
output = F.linear(input, self.weight)
|
output = F.linear(input, self.weight)
|
||||||
|
|
||||||
if self.reduce_results:
|
if self.reduce_results:
|
||||||
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
|
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
output += self.bias
|
output += self.bias
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||||
full_weight = state_dict.get('weight')
|
full_weight = state_dict.get("weight")
|
||||||
full_bias = state_dict.get('bias')
|
full_bias = state_dict.get("bias")
|
||||||
|
|
||||||
start_idx = self.rank * self.in_features_per_rank
|
start_idx = self.rank * self.in_features_per_rank
|
||||||
end_idx = start_idx + self.in_features_per_rank
|
end_idx = start_idx + self.in_features_per_rank
|
||||||
weight_slice = full_weight[:, start_idx:end_idx]
|
weight_slice = full_weight[:, start_idx:end_idx]
|
||||||
self.weight.data.copy_(weight_slice)
|
self.weight.data.copy_(weight_slice)
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias.data.copy_(full_bias)
|
self.bias.data.copy_(full_bias)
|
||||||
|
|
||||||
|
|
||||||
class ColumnParallelLinear(ParallelModel):
|
class ColumnParallelLinear(ParallelModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
process_group: dist.ProcessGroup,
|
process_group: dist.ProcessGroup,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
gather_results: bool = True
|
gather_results: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(process_group)
|
super().__init__(process_group)
|
||||||
|
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.out_features_per_rank = out_features // self.world_size
|
self.out_features_per_rank = out_features // self.world_size
|
||||||
self.gather_results = gather_results
|
self.gather_results = gather_results
|
||||||
|
|
||||||
if out_features % self.world_size != 0:
|
if out_features % self.world_size != 0:
|
||||||
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
|
raise ValueError(
|
||||||
|
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.empty(self.out_features_per_rank, self.in_features)
|
||||||
|
)
|
||||||
|
self.bias = (
|
||||||
|
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
||||||
|
)
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
output = F.linear(input, self.weight, self.bias)
|
output = F.linear(input, self.weight, self.bias)
|
||||||
|
|
||||||
if self.gather_results:
|
if self.gather_results:
|
||||||
output_list = [torch.empty_like(output) for _ in range(self.world_size)]
|
output_list = [torch.empty_like(output) for _ in range(self.world_size)]
|
||||||
dist.all_gather(output_list, output, group=self.process_group)
|
dist.all_gather(output_list, output, group=self.process_group)
|
||||||
output = torch.cat(output_list, dim=-1)
|
output = torch.cat(output_list, dim=-1)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||||
full_weight = state_dict.get('weight')
|
full_weight = state_dict.get("weight")
|
||||||
full_bias = state_dict.get('bias')
|
full_bias = state_dict.get("bias")
|
||||||
|
|
||||||
start_idx = self.rank * self.out_features_per_rank
|
start_idx = self.rank * self.out_features_per_rank
|
||||||
end_idx = start_idx + self.out_features_per_rank
|
end_idx = start_idx + self.out_features_per_rank
|
||||||
weight_slice = full_weight[start_idx:end_idx, :]
|
weight_slice = full_weight[start_idx:end_idx, :]
|
||||||
self.weight.data.copy_(weight_slice)
|
self.weight.data.copy_(weight_slice)
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
bias_slice = full_bias[start_idx:end_idx]
|
bias_slice = full_bias[start_idx:end_idx]
|
||||||
self.bias.data.copy_(bias_slice)
|
self.bias.data.copy_(bias_slice)
|
||||||
|
|
|
||||||
|
|
@ -11,73 +11,74 @@ from typing import Callable, List, Optional
|
||||||
def get_current_device():
|
def get_current_device():
|
||||||
return os.environ["LOCAL_DEVICE"]
|
return os.environ["LOCAL_DEVICE"]
|
||||||
|
|
||||||
|
|
||||||
def get_world_size() -> int:
|
def get_world_size() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
return dist.get_world_size()
|
return dist.get_world_size()
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def get_rank() -> int:
|
def get_rank() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
return dist.get_rank()
|
return dist.get_rank()
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def setup_parallel(
|
def setup_parallel(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None
|
device_ids: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
yield dist.group.WORLD
|
yield dist.group.WORLD
|
||||||
return
|
return
|
||||||
|
|
||||||
if world_size <= 1:
|
if world_size <= 1:
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
if device_ids is None:
|
if device_ids is None:
|
||||||
device_ids = [i for i in range(world_size)]
|
device_ids = [i for i in range(world_size)]
|
||||||
|
|
||||||
rank = device_ids[rank % len(device_ids)]
|
rank = device_ids[rank % len(device_ids)]
|
||||||
device_id = torch.device(device_type, device_ids[rank])
|
device_id = torch.device(device_type, device_ids[rank])
|
||||||
|
|
||||||
os.environ['MASTER_ADDR'] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ['MASTER_PORT'] = master_port
|
os.environ["MASTER_PORT"] = master_port
|
||||||
|
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
rank=rank,
|
rank=rank, world_size=world_size, backend=backend, device_id=device_id
|
||||||
world_size=world_size,
|
|
||||||
backend=backend,
|
|
||||||
device_id=device_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if backend == "nccl" and torch.cuda.is_available():
|
if backend == "nccl" and torch.cuda.is_available():
|
||||||
torch.cuda.set_device(device_id)
|
torch.cuda.set_device(device_id)
|
||||||
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
|
elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
torch.xpu.set_device(device_id)
|
torch.xpu.set_device(device_id)
|
||||||
|
|
||||||
yield dist.group.WORLD
|
yield dist.group.WORLD
|
||||||
finally:
|
finally:
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def only_on_rank(rank, sync=False):
|
def only_on_rank(rank, sync=False):
|
||||||
"""
|
"""
|
||||||
decorator to run a function only on a specific rank.
|
decorator to run a function only on a specific rank.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
|
@ -89,67 +90,81 @@ def only_on_rank(rank, sync=False):
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
return ret_args
|
return ret_args
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def wrapper_spawn_func(
|
def wrapper_spawn_func(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str,
|
backend: str,
|
||||||
master_addr: str,
|
master_addr: str,
|
||||||
master_port: str,
|
master_port: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
device_ids: List[int],
|
device_ids: List[int],
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
with setup_parallel(
|
with setup_parallel(
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
device_ids=device_ids
|
device_ids=device_ids,
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in rank {rank}: {e}")
|
print(f"Error in rank {rank}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def spawn_parallel_fn(
|
def spawn_parallel_fn(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None,
|
device_ids: Optional[List[int]] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
# clear environment variables
|
# clear environment variables
|
||||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
|
for key in [
|
||||||
|
"MASTER_ADDR",
|
||||||
|
"MASTER_PORT",
|
||||||
|
"RANK",
|
||||||
|
"WORLD_SIZE",
|
||||||
|
"LOCAL_RANK",
|
||||||
|
"LOCAL_DEVICE",
|
||||||
|
]:
|
||||||
if key in os.environ:
|
if key in os.environ:
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
device_ids = device_ids or [0]
|
device_ids = device_ids or [0]
|
||||||
device_id = torch.device(device_type, device_ids[0])
|
device_id = torch.device(device_type, device_ids[0])
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
return
|
return
|
||||||
|
|
||||||
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
|
wrapper_spawn_func_args = (
|
||||||
device_type, device_ids, func, kwargs)
|
world_size,
|
||||||
|
backend,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
device_type,
|
||||||
|
device_ids,
|
||||||
|
func,
|
||||||
|
kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
mp.spawn(
|
mp.spawn(
|
||||||
wrapper_spawn_func,
|
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
|
||||||
nprocs=world_size,
|
)
|
||||||
args=wrapper_spawn_func_args,
|
|
||||||
join=True
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -14,15 +14,12 @@ from khaosz.trainer.train_callback import (
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Main trainer
|
# Main trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
|
||||||
# Strategy factory
|
# Strategy factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"BaseStrategy",
|
"BaseStrategy",
|
||||||
|
|
||||||
# Scheduler factory
|
# Scheduler factory
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
"BaseScheduler",
|
"BaseScheduler",
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
"TrainCallback",
|
"TrainCallback",
|
||||||
"GradientClippingCallback",
|
"GradientClippingCallback",
|
||||||
|
|
@ -30,4 +27,4 @@ __all__ = [
|
||||||
"CheckpointCallback",
|
"CheckpointCallback",
|
||||||
"ProgressBarCallback",
|
"ProgressBarCallback",
|
||||||
"MetricLoggerCallback",
|
"MetricLoggerCallback",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||||
""" Compute gradient norm for each parameter in the model. """
|
"""Compute gradient norm for each parameter in the model."""
|
||||||
norms = {}
|
norms = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
norms[name] = 0.0
|
norms[name] = 0.0
|
||||||
|
|
@ -11,8 +12,9 @@ def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||||
norms[name] = norm
|
norms[name] = norm
|
||||||
return norms
|
return norms
|
||||||
|
|
||||||
|
|
||||||
def grad_std(model: nn.Module) -> Dict[str, float]:
|
def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Compute standard deviation of gradients for each parameter. """
|
"""Compute standard deviation of gradients for each parameter."""
|
||||||
stds = {}
|
stds = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
stds[name] = 0.0
|
stds[name] = 0.0
|
||||||
|
|
@ -21,41 +23,45 @@ def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||||
stds[name] = std
|
stds[name] = std
|
||||||
return stds
|
return stds
|
||||||
|
|
||||||
|
|
||||||
def grad_max(model: nn.Module) -> Dict[str, float]:
|
def grad_max(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Find the maximum absolute gradient value for each parameter. """
|
"""Find the maximum absolute gradient value for each parameter."""
|
||||||
max_vals = {}
|
max_vals = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
max_vals[name] = -float('inf')
|
max_vals[name] = -float("inf")
|
||||||
if param.grad:
|
if param.grad:
|
||||||
max_val = param.grad.data.max().item()
|
max_val = param.grad.data.max().item()
|
||||||
max_vals[name] = max_val
|
max_vals[name] = max_val
|
||||||
|
|
||||||
return max_vals
|
return max_vals
|
||||||
|
|
||||||
|
|
||||||
def grad_min(model: nn.Module) -> Dict[str, float]:
|
def grad_min(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Find the minimum absolute gradient value for each parameter. """
|
"""Find the minimum absolute gradient value for each parameter."""
|
||||||
min_vals = {}
|
min_vals = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
min_vals[name] = float('inf')
|
min_vals[name] = float("inf")
|
||||||
if param.grad:
|
if param.grad:
|
||||||
min_val = param.grad.data.min().item()
|
min_val = param.grad.data.min().item()
|
||||||
min_vals[name] = min_val
|
min_vals[name] = min_val
|
||||||
|
|
||||||
return min_vals
|
return min_vals
|
||||||
|
|
||||||
|
|
||||||
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Compute mean of gradients for each parameter. """
|
"""Compute mean of gradients for each parameter."""
|
||||||
means = {}
|
means = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
means[name] = 0.0
|
means[name] = 0.0
|
||||||
if param.grad:
|
if param.grad:
|
||||||
mean = param.grad.data.mean().item()
|
mean = param.grad.data.mean().item()
|
||||||
means[name] = mean
|
means[name] = mean
|
||||||
|
|
||||||
return means
|
return means
|
||||||
|
|
||||||
|
|
||||||
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||||
""" Count the number of NaNs in gradients for each parameter. """
|
"""Count the number of NaNs in gradients for each parameter."""
|
||||||
nan_nums = {}
|
nan_nums = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
nan_nums[name] = 0
|
nan_nums[name] = 0
|
||||||
|
|
@ -64,26 +70,34 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||||
nan_nums[name] = nan_num
|
nan_nums[name] = nan_num
|
||||||
return nan_nums
|
return nan_nums
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_loss(ctx):
|
def ctx_get_loss(ctx):
|
||||||
return ctx.loss
|
return ctx.loss
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_lr(ctx):
|
def ctx_get_lr(ctx):
|
||||||
return ctx.optimizer.param_groups[-1]['lr']
|
return ctx.optimizer.param_groups[-1]["lr"]
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_norm(ctx):
|
def ctx_get_grad_norm(ctx):
|
||||||
return grad_norm(ctx.model)
|
return grad_norm(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_std(ctx):
|
def ctx_get_grad_std(ctx):
|
||||||
return grad_std(ctx.model)
|
return grad_std(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_max(ctx):
|
def ctx_get_grad_max(ctx):
|
||||||
return grad_max(ctx.model)
|
return grad_max(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_min(ctx):
|
def ctx_get_grad_min(ctx):
|
||||||
return grad_min(ctx.model)
|
return grad_min(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_mean(ctx):
|
def ctx_get_grad_mean(ctx):
|
||||||
return grad_mean(ctx.model)
|
return grad_mean(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_nan_num(ctx):
|
def ctx_get_grad_nan_num(ctx):
|
||||||
return grad_nan_num(ctx.model)
|
return grad_nan_num(ctx.model)
|
||||||
|
|
|
||||||
|
|
@ -9,71 +9,75 @@ from khaosz.config.schedule_config import ScheduleConfig
|
||||||
|
|
||||||
class BaseScheduler(LRScheduler, ABC):
|
class BaseScheduler(LRScheduler, ABC):
|
||||||
"""Base scheduler class for all other schedulers."""
|
"""Base scheduler class for all other schedulers."""
|
||||||
|
|
||||||
def __init__(self, optimizer, last_epoch: int = -1):
|
def __init__(self, optimizer, last_epoch: int = -1):
|
||||||
super().__init__(optimizer, last_epoch)
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_lr(self) -> List[float]:
|
def get_lr(self) -> List[float]:
|
||||||
"""Calculate the current learning rate."""
|
"""Calculate the current learning rate."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def state_dict(self) -> Dict[str, Any]:
|
def state_dict(self) -> Dict[str, Any]:
|
||||||
return super().state_dict()
|
return super().state_dict()
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Any]):
|
def load_state_dict(self, state_dict: Dict[str, Any]):
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerFactory:
|
class SchedulerFactory:
|
||||||
"""Factory class for creating learning rate schedulers.
|
"""Factory class for creating learning rate schedulers.
|
||||||
|
|
||||||
Supports decorator-based registration for extensible scheduler types.
|
Supports decorator-based registration for extensible scheduler types.
|
||||||
Also supports creation from ScheduleConfig objects.
|
Also supports creation from ScheduleConfig objects.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
@SchedulerFactory.register("custom")
|
@SchedulerFactory.register("custom")
|
||||||
class CustomScheduler(BaseScheduler):
|
class CustomScheduler(BaseScheduler):
|
||||||
...
|
...
|
||||||
|
|
||||||
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)
|
||||||
|
|
||||||
# Or from config
|
# Or from config
|
||||||
config = CosineScheduleConfig(total_steps=10000)
|
config = CosineScheduleConfig(total_steps=10000)
|
||||||
scheduler = SchedulerFactory.load(optimizer, config)
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
|
SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str):
|
def register(cls, name: str):
|
||||||
"""Decorator to register a new scheduler class.
|
"""Decorator to register a new scheduler class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Registration name for the scheduler
|
name: Registration name for the scheduler
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function that registers the scheduler class
|
Decorator function that registers the scheduler class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
||||||
if not issubclass(scheduler_cls, BaseScheduler):
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
raise TypeError(
|
||||||
|
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
|
||||||
|
)
|
||||||
cls.SCHEDULER_MAP[name] = scheduler_cls
|
cls.SCHEDULER_MAP[name] = scheduler_cls
|
||||||
return scheduler_cls
|
return scheduler_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
|
def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
|
||||||
"""Create a scheduler instance by type name.
|
"""Create a scheduler instance by type name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer: PyTorch optimizer
|
optimizer: PyTorch optimizer
|
||||||
schedule_type: Type of scheduler ("cosine", "sgdr")
|
schedule_type: Type of scheduler ("cosine", "sgdr")
|
||||||
**kwargs: Arguments passed to the scheduler constructor
|
**kwargs: Arguments passed to the scheduler constructor
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Scheduler instance
|
Scheduler instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If schedule_type is not supported
|
ValueError: If schedule_type is not supported
|
||||||
"""
|
"""
|
||||||
|
|
@ -82,25 +86,25 @@ class SchedulerFactory:
|
||||||
f"Unknown schedule type: '{schedule_type}'. "
|
f"Unknown schedule type: '{schedule_type}'. "
|
||||||
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
|
f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
|
scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
|
||||||
return scheduler_cls(optimizer, **kwargs)
|
return scheduler_cls(optimizer, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
||||||
"""Create a scheduler from a ScheduleConfig object.
|
"""Create a scheduler from a ScheduleConfig object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer: PyTorch optimizer
|
optimizer: PyTorch optimizer
|
||||||
schedule_config: ScheduleConfig instance
|
schedule_config: ScheduleConfig instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Scheduler instance
|
Scheduler instance
|
||||||
"""
|
"""
|
||||||
kwargs = schedule_config.get_kwargs()
|
kwargs = schedule_config.get_kwargs()
|
||||||
schedule_type = kwargs.pop("schedule_type")
|
schedule_type = kwargs.pop("schedule_type")
|
||||||
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
|
return SchedulerFactory.create(optimizer, schedule_type, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def available_types(cls) -> list:
|
def available_types(cls) -> list:
|
||||||
"""Return list of registered scheduler type names."""
|
"""Return list of registered scheduler type names."""
|
||||||
|
|
@ -114,22 +118,21 @@ class SchedulerFactory:
|
||||||
@SchedulerFactory.register("cosine")
|
@SchedulerFactory.register("cosine")
|
||||||
class CosineScheduler(BaseScheduler):
|
class CosineScheduler(BaseScheduler):
|
||||||
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
|
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
optimizer,
|
optimizer,
|
||||||
warmup_steps: int,
|
warmup_steps: int,
|
||||||
lr_decay_steps: int,
|
lr_decay_steps: int,
|
||||||
min_rate: float = 0.05,
|
min_rate: float = 0.05,
|
||||||
last_epoch: int = -1
|
last_epoch: int = -1,
|
||||||
):
|
):
|
||||||
self.warmup_steps = warmup_steps
|
self.warmup_steps = warmup_steps
|
||||||
self.lr_decay_steps = lr_decay_steps
|
self.lr_decay_steps = lr_decay_steps
|
||||||
self.min_rate = min_rate
|
self.min_rate = min_rate
|
||||||
self.total_steps = warmup_steps + lr_decay_steps
|
self.total_steps = warmup_steps + lr_decay_steps
|
||||||
super().__init__(optimizer, last_epoch)
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
def get_lr(self) -> List[float]:
|
def get_lr(self) -> List[float]:
|
||||||
# warmup
|
# warmup
|
||||||
if self.last_epoch < self.warmup_steps:
|
if self.last_epoch < self.warmup_steps:
|
||||||
|
|
@ -142,46 +145,47 @@ class CosineScheduler(BaseScheduler):
|
||||||
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
|
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
|
||||||
decay_factor = max(self.min_rate, cosine_decay)
|
decay_factor = max(self.min_rate, cosine_decay)
|
||||||
return [base_lr * decay_factor for base_lr in self.base_lrs]
|
return [base_lr * decay_factor for base_lr in self.base_lrs]
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
state = super().state_dict()
|
state = super().state_dict()
|
||||||
state.update({
|
state.update(
|
||||||
'warmup_steps': self.warmup_steps,
|
{
|
||||||
'lr_decay_steps': self.lr_decay_steps,
|
"warmup_steps": self.warmup_steps,
|
||||||
'min_rate': self.min_rate,
|
"lr_decay_steps": self.lr_decay_steps,
|
||||||
'total_steps': self.total_steps,
|
"min_rate": self.min_rate,
|
||||||
})
|
"total_steps": self.total_steps,
|
||||||
|
}
|
||||||
|
)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
self.warmup_steps = state_dict.pop('warmup_steps')
|
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||||
self.lr_decay_steps = state_dict.pop('lr_decay_steps')
|
self.lr_decay_steps = state_dict.pop("lr_decay_steps")
|
||||||
self.min_rate = state_dict.pop('min_rate')
|
self.min_rate = state_dict.pop("min_rate")
|
||||||
self.total_steps = state_dict.pop('total_steps')
|
self.total_steps = state_dict.pop("total_steps")
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
@SchedulerFactory.register("sgdr")
|
@SchedulerFactory.register("sgdr")
|
||||||
class SGDRScheduler(BaseScheduler):
|
class SGDRScheduler(BaseScheduler):
|
||||||
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
|
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
optimizer,
|
optimizer,
|
||||||
warmup_steps: int,
|
warmup_steps: int,
|
||||||
cycle_length: int,
|
cycle_length: int,
|
||||||
min_rate: float = 0.05,
|
min_rate: float = 0.05,
|
||||||
t_mult: int = 2,
|
t_mult: int = 2,
|
||||||
last_epoch: int = -1,
|
last_epoch: int = -1,
|
||||||
):
|
):
|
||||||
self.warmup_steps = warmup_steps
|
self.warmup_steps = warmup_steps
|
||||||
self.cycle_length = cycle_length
|
self.cycle_length = cycle_length
|
||||||
self.min_rate = min_rate
|
self.min_rate = min_rate
|
||||||
self.t_mult = t_mult
|
self.t_mult = t_mult
|
||||||
|
|
||||||
super().__init__(optimizer, last_epoch)
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
# warmup
|
# warmup
|
||||||
if self.last_epoch < self.warmup_steps:
|
if self.last_epoch < self.warmup_steps:
|
||||||
|
|
@ -190,40 +194,44 @@ class SGDRScheduler(BaseScheduler):
|
||||||
|
|
||||||
# SGDR
|
# SGDR
|
||||||
steps_since_warmup = self.last_epoch - self.warmup_steps
|
steps_since_warmup = self.last_epoch - self.warmup_steps
|
||||||
|
|
||||||
# 1. Calculate current cycle and position within cycle
|
# 1. Calculate current cycle and position within cycle
|
||||||
current_cycle_length = self.cycle_length
|
current_cycle_length = self.cycle_length
|
||||||
total_cycles_length = 0
|
total_cycles_length = 0
|
||||||
cycle_num = 0
|
cycle_num = 0
|
||||||
|
|
||||||
while total_cycles_length + current_cycle_length <= steps_since_warmup:
|
while total_cycles_length + current_cycle_length <= steps_since_warmup:
|
||||||
total_cycles_length += current_cycle_length
|
total_cycles_length += current_cycle_length
|
||||||
current_cycle_length *= self.t_mult
|
current_cycle_length *= self.t_mult
|
||||||
cycle_num += 1
|
cycle_num += 1
|
||||||
|
|
||||||
steps_in_cycle = steps_since_warmup - total_cycles_length
|
steps_in_cycle = steps_since_warmup - total_cycles_length
|
||||||
|
|
||||||
# 2. Cosine annealing within the current cycle
|
# 2. Cosine annealing within the current cycle
|
||||||
cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
|
cosine_factor = 0.5 * (
|
||||||
|
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
|
||||||
|
)
|
||||||
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
|
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
|
||||||
|
|
||||||
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
|
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
"""Returns the state of the scheduler as a dict."""
|
"""Returns the state of the scheduler as a dict."""
|
||||||
state = super().state_dict()
|
state = super().state_dict()
|
||||||
state.update({
|
state.update(
|
||||||
'warmup_steps': self.warmup_steps,
|
{
|
||||||
'cycle_length': self.cycle_length,
|
"warmup_steps": self.warmup_steps,
|
||||||
'min_rate': self.min_rate,
|
"cycle_length": self.cycle_length,
|
||||||
't_mult': self.t_mult
|
"min_rate": self.min_rate,
|
||||||
})
|
"t_mult": self.t_mult,
|
||||||
|
}
|
||||||
|
)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
"""Loads the scheduler's state."""
|
"""Loads the scheduler's state."""
|
||||||
self.warmup_steps = state_dict.pop('warmup_steps')
|
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||||
self.cycle_length = state_dict.pop('cycle_length')
|
self.cycle_length = state_dict.pop("cycle_length")
|
||||||
self.min_rate = state_dict.pop('min_rate')
|
self.min_rate = state_dict.pop("min_rate")
|
||||||
self.t_mult = state_dict.pop('t_mult')
|
self.t_mult = state_dict.pop("t_mult")
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
|
|
||||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||||
"""Create a reference model for DPO/GRPO training.
|
"""Create a reference model for DPO/GRPO training.
|
||||||
|
|
||||||
Handles DDP-wrapped models safely by unwrapping first,
|
Handles DDP-wrapped models safely by unwrapping first,
|
||||||
then creating a deep copy with frozen gradients.
|
then creating a deep copy with frozen gradients.
|
||||||
"""
|
"""
|
||||||
|
|
@ -37,25 +37,27 @@ def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
|
||||||
|
|
||||||
|
|
||||||
def get_logprobs(
|
def get_logprobs(
|
||||||
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
mask: Tensor,
|
mask: Tensor,
|
||||||
reduction: str,
|
reduction: str,
|
||||||
):
|
):
|
||||||
"""Compute token-wise log probabilities from model outputs.
|
"""Compute token-wise log probabilities from model outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The language model
|
model: The language model
|
||||||
input_ids: Input token IDs of shape [batch_size, seq_len]
|
input_ids: Input token IDs of shape [batch_size, seq_len]
|
||||||
mask: Attention mask of shape [batch_size, seq_len]
|
mask: Attention mask of shape [batch_size, seq_len]
|
||||||
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
|
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Log probabilities with reduction applied over sequence dimension
|
Log probabilities with reduction applied over sequence dimension
|
||||||
"""
|
"""
|
||||||
allowed_reductions = ["mean", "sum", "none"]
|
allowed_reductions = ["mean", "sum", "none"]
|
||||||
if reduction not in allowed_reductions:
|
if reduction not in allowed_reductions:
|
||||||
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
|
raise ValueError(
|
||||||
|
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
|
||||||
|
)
|
||||||
|
|
||||||
shifted_input_ids = input_ids[:, 1:]
|
shifted_input_ids = input_ids[:, 1:]
|
||||||
shifted_mask = mask[:, 1:]
|
shifted_mask = mask[:, 1:]
|
||||||
|
|
@ -64,13 +66,13 @@ def get_logprobs(
|
||||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||||
|
|
||||||
token_logprobs = torch.gather(
|
token_logprobs = torch.gather(
|
||||||
log_probs,
|
log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
|
||||||
dim=-1,
|
|
||||||
index=shifted_input_ids.unsqueeze(-1)
|
|
||||||
).squeeze(-1)
|
).squeeze(-1)
|
||||||
|
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0)
|
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
|
||||||
|
dim=-1
|
||||||
|
).clamp(min=1.0)
|
||||||
elif reduction == "sum":
|
elif reduction == "sum":
|
||||||
return (token_logprobs * shifted_mask).sum(dim=-1)
|
return (token_logprobs * shifted_mask).sum(dim=-1)
|
||||||
else:
|
else:
|
||||||
|
|
@ -79,23 +81,25 @@ def get_logprobs(
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
"""Abstract base class for training strategies."""
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
|
def __init__(
|
||||||
|
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
|
||||||
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
"""Compute loss for the given batch.
|
"""Compute loss for the given batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: Dictionary containing batch tensors
|
batch: Dictionary containing batch tensors
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Computed loss tensor
|
Computed loss tensor
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
"""Allow calling strategy directly as a callable."""
|
"""Allow calling strategy directly as a callable."""
|
||||||
return self.compute_loss(batch)
|
return self.compute_loss(batch)
|
||||||
|
|
@ -103,51 +107,55 @@ class BaseStrategy(ABC):
|
||||||
|
|
||||||
class StrategyFactory:
|
class StrategyFactory:
|
||||||
"""Factory class for creating training strategy instances.
|
"""Factory class for creating training strategy instances.
|
||||||
|
|
||||||
Supports decorator-based registration for extensible strategy types.
|
Supports decorator-based registration for extensible strategy types.
|
||||||
All default strategies (seq, sft, dpo, grpo) are automatically registered.
|
All default strategies (seq, sft, dpo, grpo) are automatically registered.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
@StrategyFactory.register("custom")
|
@StrategyFactory.register("custom")
|
||||||
class CustomStrategy(BaseStrategy):
|
class CustomStrategy(BaseStrategy):
|
||||||
...
|
...
|
||||||
|
|
||||||
strategy = StrategyFactory.create(model, "custom", device)
|
strategy = StrategyFactory.create(model, "custom", device)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
|
SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
|
||||||
STRATEGY_MAP: Dict[str, type] = {}
|
STRATEGY_MAP: Dict[str, type] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str):
|
def register(cls, name: str):
|
||||||
"""Decorator to register a new strategy class.
|
"""Decorator to register a new strategy class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Registration name for the strategy
|
name: Registration name for the strategy
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function that registers the strategy class
|
Decorator function that registers the strategy class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(strategy_cls: type) -> type:
|
def decorator(strategy_cls: type) -> type:
|
||||||
if not issubclass(strategy_cls, BaseStrategy):
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
raise TypeError(
|
||||||
|
f"{strategy_cls.__name__} must inherit from BaseStrategy"
|
||||||
|
)
|
||||||
cls.STRATEGY_MAP[name] = strategy_cls
|
cls.STRATEGY_MAP[name] = strategy_cls
|
||||||
return strategy_cls
|
return strategy_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
|
def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
|
||||||
"""Create a strategy instance based on training type.
|
"""Create a strategy instance based on training type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model instance for the strategy
|
model: Model instance for the strategy
|
||||||
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
train_type: Type of training ("seq", "sft", "dpo", "grpo")
|
||||||
device: Device to run the strategy on
|
device: Device to run the strategy on
|
||||||
**kwargs: Additional arguments passed to strategy constructor
|
**kwargs: Additional arguments passed to strategy constructor
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Strategy instance
|
Strategy instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If train_type is not supported
|
ValueError: If train_type is not supported
|
||||||
NotImplementedError: If train_type is in supported list but not implemented
|
NotImplementedError: If train_type is in supported list but not implemented
|
||||||
|
|
@ -157,15 +165,15 @@ class StrategyFactory:
|
||||||
f"Unknown training strategy: '{train_type}'. "
|
f"Unknown training strategy: '{train_type}'. "
|
||||||
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
|
f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if train_type not in cls.STRATEGY_MAP:
|
if train_type not in cls.STRATEGY_MAP:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Strategy '{train_type}' is supported but not yet implemented."
|
f"Strategy '{train_type}' is supported but not yet implemented."
|
||||||
)
|
)
|
||||||
|
|
||||||
strategy_cls = cls.STRATEGY_MAP[train_type]
|
strategy_cls = cls.STRATEGY_MAP[train_type]
|
||||||
return strategy_cls(model, device, **kwargs)
|
return strategy_cls(model, device, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def available_strategies(cls) -> list:
|
def available_strategies(cls) -> list:
|
||||||
"""Return list of registered strategy names."""
|
"""Return list of registered strategy names."""
|
||||||
|
|
@ -179,77 +187,81 @@ class StrategyFactory:
|
||||||
@StrategyFactory.register("seq")
|
@StrategyFactory.register("seq")
|
||||||
class SEQStrategy(BaseStrategy):
|
class SEQStrategy(BaseStrategy):
|
||||||
"""Standard next-token prediction training strategy.
|
"""Standard next-token prediction training strategy.
|
||||||
|
|
||||||
Computes cross-entropy loss for next token prediction.
|
Computes cross-entropy loss for next token prediction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||||
logits = self.model(input_ids=input_ids)["logits"]
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten(),
|
target=target_ids.flatten(),
|
||||||
label_smoothing=self.label_smoothing
|
label_smoothing=self.label_smoothing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@StrategyFactory.register("sft")
|
@StrategyFactory.register("sft")
|
||||||
class SFTStrategy(BaseStrategy):
|
class SFTStrategy(BaseStrategy):
|
||||||
"""Supervised Fine-tuning strategy with loss masking.
|
"""Supervised Fine-tuning strategy with loss masking.
|
||||||
|
|
||||||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
|
input_ids, target_ids, loss_mask = (
|
||||||
|
batch["input_ids"],
|
||||||
|
batch["target_ids"],
|
||||||
|
batch["loss_mask"],
|
||||||
|
)
|
||||||
|
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
logits = self.model(input_ids=input_ids)["logits"]
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten(),
|
target=target_ids.flatten(),
|
||||||
ignore_index=ignore_index,
|
ignore_index=ignore_index,
|
||||||
label_smoothing=self.label_smoothing
|
label_smoothing=self.label_smoothing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@StrategyFactory.register("dpo")
|
@StrategyFactory.register("dpo")
|
||||||
class DPOStrategy(BaseStrategy):
|
class DPOStrategy(BaseStrategy):
|
||||||
"""Direct Preference Optimization strategy.
|
"""Direct Preference Optimization strategy.
|
||||||
|
|
||||||
Implements the DPO loss from the paper "Direct Preference Optimization".
|
Implements the DPO loss from the paper "Direct Preference Optimization".
|
||||||
Uses a reference model to compute KL divergence penalty.
|
Uses a reference model to compute KL divergence penalty.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
device: str,
|
device: str,
|
||||||
beta: float = 0.1,
|
beta: float = 0.1,
|
||||||
reduction: str = "mean",
|
reduction: str = "mean",
|
||||||
):
|
):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(model)
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
||||||
|
|
@ -257,17 +269,19 @@ class DPOStrategy(BaseStrategy):
|
||||||
|
|
||||||
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
||||||
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
||||||
|
|
||||||
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
|
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
log_ref = get_logprobs(self.ref_model, contact_ids, contact_mask, self.reduction)
|
log_ref = get_logprobs(
|
||||||
|
self.ref_model, contact_ids, contact_mask, self.reduction
|
||||||
log_pi_chosen = log_pi[:chosen_ids.shape[0]]
|
)
|
||||||
log_pi_rejected = log_pi[chosen_ids.shape[0]:]
|
|
||||||
log_ref_chosen = log_ref[:chosen_ids.shape[0]]
|
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
||||||
log_ref_rejected = log_ref[chosen_ids.shape[0]:]
|
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
|
||||||
|
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
|
||||||
|
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
|
||||||
|
|
||||||
pi_log_ratio = log_pi_chosen - log_pi_rejected
|
pi_log_ratio = log_pi_chosen - log_pi_rejected
|
||||||
ref_log_ratio = log_ref_chosen - log_ref_rejected
|
ref_log_ratio = log_ref_chosen - log_ref_rejected
|
||||||
|
|
||||||
|
|
@ -280,14 +294,14 @@ class DPOStrategy(BaseStrategy):
|
||||||
@StrategyFactory.register("grpo")
|
@StrategyFactory.register("grpo")
|
||||||
class GRPOStrategy(BaseStrategy):
|
class GRPOStrategy(BaseStrategy):
|
||||||
"""Group Relative Policy Optimization strategy.
|
"""Group Relative Policy Optimization strategy.
|
||||||
|
|
||||||
Implements GRPO with clipping and KL penalty.
|
Implements GRPO with clipping and KL penalty.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
device: str,
|
device: str,
|
||||||
clip_eps: float = 0.2,
|
clip_eps: float = 0.2,
|
||||||
kl_coef: float = 0.01,
|
kl_coef: float = 0.01,
|
||||||
group_size: int = 4,
|
group_size: int = 4,
|
||||||
|
|
@ -299,43 +313,47 @@ class GRPOStrategy(BaseStrategy):
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
prompts = batch["prompts"]
|
prompts = batch["prompts"]
|
||||||
responses = batch["responses"]
|
responses = batch["responses"]
|
||||||
masks = batch["masks"]
|
masks = batch["masks"]
|
||||||
rewards = batch["rewards"]
|
rewards = batch["rewards"]
|
||||||
|
|
||||||
batch_size, group_size, response_len = responses.shape
|
batch_size, group_size, response_len = responses.shape
|
||||||
responses_flat = responses.view(-1, response_len)
|
responses_flat = responses.view(-1, response_len)
|
||||||
masks_flat = masks.view(-1, response_len)
|
masks_flat = masks.view(-1, response_len)
|
||||||
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
|
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
|
||||||
|
|
||||||
# Shape: (batch_size * group_size, seq_len + response_len)
|
# Shape: (batch_size * group_size, seq_len + response_len)
|
||||||
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
||||||
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
||||||
|
|
||||||
log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction)
|
log_probs_policy = get_logprobs(
|
||||||
|
self.model, full_sequences, full_masks, self.reduction
|
||||||
|
)
|
||||||
log_probs_policy = log_probs_policy.view(batch_size, group_size)
|
log_probs_policy = log_probs_policy.view(batch_size, group_size)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
|
log_probs_ref = get_logprobs(
|
||||||
|
self.ref_model, full_sequences, full_masks, self.reduction
|
||||||
|
)
|
||||||
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||||
|
|
||||||
# Compute advantages from rewards with normalization
|
# Compute advantages from rewards with normalization
|
||||||
eps = torch.finfo(log_probs_policy.dtype).eps
|
eps = torch.finfo(log_probs_policy.dtype).eps
|
||||||
mean = rewards.mean(dim=-1, keepdim=True)
|
mean = rewards.mean(dim=-1, keepdim=True)
|
||||||
std = rewards.std(dim=-1, keepdim=True)
|
std = rewards.std(dim=-1, keepdim=True)
|
||||||
advantages = (rewards - mean) / (std + eps)
|
advantages = (rewards - mean) / (std + eps)
|
||||||
|
|
||||||
# PPO-style clipped surrogate objective
|
# PPO-style clipped surrogate objective
|
||||||
ratio = torch.exp(0) # Off-policy: policy_model = old_model
|
ratio = torch.exp(0) # Off-policy: policy_model = old_model
|
||||||
surr1 = ratio * advantages
|
surr1 = ratio * advantages
|
||||||
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||||
|
|
||||||
policy_loss = -torch.min(surr1, surr2).mean()
|
policy_loss = -torch.min(surr1, surr2).mean()
|
||||||
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
|
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
|
||||||
total_loss = policy_loss + kl_penalty
|
total_loss = policy_loss + kl_penalty
|
||||||
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
|
||||||
|
|
@ -18,52 +18,53 @@ from khaosz.trainer.metric_util import (
|
||||||
ctx_get_grad_norm,
|
ctx_get_grad_norm,
|
||||||
ctx_get_grad_mean,
|
ctx_get_grad_mean,
|
||||||
ctx_get_grad_std,
|
ctx_get_grad_std,
|
||||||
ctx_get_grad_nan_num
|
ctx_get_grad_nan_num,
|
||||||
)
|
)
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
from khaosz.trainer.train_context import TrainContext
|
from khaosz.trainer.train_context import TrainContext
|
||||||
|
|
||||||
|
|
||||||
class TrainCallback(Protocol):
|
class TrainCallback(Protocol):
|
||||||
"""
|
"""
|
||||||
Callback interface for trainer.
|
Callback interface for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def on_train_begin(self, context: TrainContext):
|
def on_train_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of training. """
|
"""Called at the beginning of training."""
|
||||||
|
|
||||||
def on_train_end(self, context: TrainContext):
|
def on_train_end(self, context: TrainContext):
|
||||||
""" Called at the end of training. """
|
"""Called at the end of training."""
|
||||||
|
|
||||||
def on_epoch_begin(self, context: TrainContext):
|
def on_epoch_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of each epoch. """
|
"""Called at the beginning of each epoch."""
|
||||||
|
|
||||||
def on_epoch_end(self, context: TrainContext):
|
def on_epoch_end(self, context: TrainContext):
|
||||||
""" Called at the end of each epoch. """
|
"""Called at the end of each epoch."""
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
def on_step_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of each step. """
|
"""Called at the beginning of each step."""
|
||||||
|
|
||||||
def on_step_end(self, context: TrainContext):
|
def on_step_end(self, context: TrainContext):
|
||||||
""" Called at the end of each step."""
|
"""Called at the end of each step."""
|
||||||
|
|
||||||
def on_batch_begin(self, context: TrainContext):
|
def on_batch_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of each batch. """
|
"""Called at the beginning of each batch."""
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
""" Called at the end of each batch. """
|
"""Called at the end of each batch."""
|
||||||
|
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
""" Called when an error occurs during training. """
|
"""Called when an error occurs during training."""
|
||||||
|
|
||||||
|
|
||||||
class GradientClippingCallback(TrainCallback):
|
class GradientClippingCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Gradient clipping callback for trainer.
|
Gradient clipping callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
def on_step_begin(self, context: TrainContext):
|
||||||
_ = context
|
_ = context
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
@ -73,86 +74,95 @@ class SchedulerCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Scheduler callback for trainer.
|
Scheduler callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_train_begin(self, context: TrainContext):
|
def on_train_begin(self, context: TrainContext):
|
||||||
for group in context.optimizer.param_groups:
|
for group in context.optimizer.param_groups:
|
||||||
if "initial_lr" not in group:
|
if "initial_lr" not in group:
|
||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
if context.scheduler:
|
if context.scheduler:
|
||||||
context.scheduler.step()
|
context.scheduler.step()
|
||||||
|
|
||||||
|
|
||||||
class CheckpointCallback(TrainCallback):
|
class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
self.state_dict_fn = state_dict_fn
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
save_path = os.path.join(
|
||||||
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
|
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||||
|
)
|
||||||
|
state_dict = (
|
||||||
|
self.state_dict_fn(context.model)
|
||||||
|
if self.state_dict_fn
|
||||||
|
else context.model.state_dict()
|
||||||
|
)
|
||||||
|
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
|
||||||
epoch=context.epoch,
|
|
||||||
iteration=context.iteration
|
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
def on_train_end(self, context: TrainContext):
|
def on_train_end(self, context: TrainContext):
|
||||||
if context.iteration != self.last_ckpt_iter:
|
if context.iteration != self.last_ckpt_iter:
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
|
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_epoch: int):
|
def __init__(self, num_epoch: int):
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
self.progress_bar: tqdm = None
|
self.progress_bar: tqdm = None
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def on_epoch_begin(self, context: TrainContext):
|
def on_epoch_begin(self, context: TrainContext):
|
||||||
self.progress_bar = tqdm(
|
self.progress_bar = tqdm(
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True
|
dynamic_ncols=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
self.progress_bar.set_postfix({
|
self.progress_bar.set_postfix(
|
||||||
"loss": f"{context.loss:.4f}",
|
{
|
||||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
|
"loss": f"{context.loss:.4f}",
|
||||||
})
|
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||||
|
}
|
||||||
|
)
|
||||||
self.progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def on_epoch_end(self, context: TrainContext):
|
def on_epoch_end(self, context: TrainContext):
|
||||||
_ = context
|
_ = context
|
||||||
|
|
@ -162,66 +172,65 @@ class ProgressBarCallback(TrainCallback):
|
||||||
|
|
||||||
class MetricLoggerCallback(TrainCallback):
|
class MetricLoggerCallback(TrainCallback):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
log_dir:str,
|
log_dir: str,
|
||||||
save_interval:int,
|
save_interval: int,
|
||||||
log_interval:int=10,
|
log_interval: int = 10,
|
||||||
metrics:List[str]=None
|
metrics: List[str] = None,
|
||||||
):
|
):
|
||||||
self.last_log_iter = 0
|
self.last_log_iter = 0
|
||||||
self.save_interval = save_interval
|
self.save_interval = save_interval
|
||||||
self.log_interval = log_interval
|
self.log_interval = log_interval
|
||||||
self.metrics = metrics or ['loss', 'lr']
|
self.metrics = metrics or ["loss", "lr"]
|
||||||
|
|
||||||
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
|
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
|
||||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self.log_cache = []
|
self.log_cache = []
|
||||||
|
|
||||||
self._metric_funcs = {
|
self._metric_funcs = {
|
||||||
'loss': ctx_get_loss,
|
"loss": ctx_get_loss,
|
||||||
'lr': ctx_get_lr,
|
"lr": ctx_get_lr,
|
||||||
'grad_norm': ctx_get_grad_norm,
|
"grad_norm": ctx_get_grad_norm,
|
||||||
'grad_std': ctx_get_grad_std,
|
"grad_std": ctx_get_grad_std,
|
||||||
'grad_max': ctx_get_grad_max,
|
"grad_max": ctx_get_grad_max,
|
||||||
'grad_min': ctx_get_grad_min,
|
"grad_min": ctx_get_grad_min,
|
||||||
'grad_mean': ctx_get_grad_mean,
|
"grad_mean": ctx_get_grad_mean,
|
||||||
'grad_nan_num': ctx_get_grad_nan_num
|
"grad_nan_num": ctx_get_grad_nan_num,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_log_data(self, context: TrainContext):
|
def _get_log_data(self, context: TrainContext):
|
||||||
return {
|
return {
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.iteration,
|
"iter": context.iteration,
|
||||||
**{m: self._metric_funcs[m](context) for m in self.metrics}
|
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||||
}
|
}
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _add_log(self, log_data):
|
def _add_log(self, log_data):
|
||||||
self.log_cache.append(log_data)
|
self.log_cache.append(log_data)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _save_log(self, epoch, iter):
|
def _save_log(self, epoch, iter):
|
||||||
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
|
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
|
||||||
|
|
||||||
with open(log_file, 'w') as f:
|
with open(log_file, "w") as f:
|
||||||
for log in self.log_cache:
|
for log in self.log_cache:
|
||||||
f.write(json.dumps(log) + '\n')
|
f.write(json.dumps(log) + "\n")
|
||||||
|
|
||||||
def on_batch_end(self, context):
|
def on_batch_end(self, context):
|
||||||
if context.iteration % self.log_interval == 0:
|
if context.iteration % self.log_interval == 0:
|
||||||
log_data = self._get_log_data(context)
|
log_data = self._get_log_data(context)
|
||||||
self._add_log(log_data)
|
self._add_log(log_data)
|
||||||
|
|
||||||
if context.iteration - self.last_log_iter >= self.save_interval:
|
if context.iteration - self.last_log_iter >= self.save_interval:
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
self.last_log_iter = context.iteration
|
self.last_log_iter = context.iteration
|
||||||
|
|
||||||
def on_train_end(self, context):
|
def on_train_end(self, context):
|
||||||
if context.iteration != self.last_log_iter:
|
if context.iteration != self.last_log_iter:
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
def on_error(self, context):
|
def on_error(self, context):
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
|
|
@ -21,11 +21,11 @@ class TrainContext:
|
||||||
optimizer: Optimizer = field(default=None)
|
optimizer: Optimizer = field(default=None)
|
||||||
scheduler: LRScheduler = field(default=None)
|
scheduler: LRScheduler = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
|
|
||||||
world_size: int = field(default=1)
|
world_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
@ -39,17 +39,17 @@ class TrainContextBuilder:
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
)
|
)
|
||||||
|
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
self._context.model = self._context.model.to(device=device)
|
self._context.model = self._context.model.to(device=device)
|
||||||
|
|
||||||
if self.config.nprocs > 1:
|
if self.config.nprocs > 1:
|
||||||
fn = self.config.parallel_wrapper
|
fn = self.config.parallel_wrapper
|
||||||
self._context.model = fn(self._context.model)
|
self._context.model = fn(self._context.model)
|
||||||
|
|
||||||
self._context.optimizer = self.config.optimizer_fn(self._context.model)
|
self._context.optimizer = self.config.optimizer_fn(self._context.model)
|
||||||
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
|
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
|
|
@ -60,10 +60,10 @@ class TrainContextBuilder:
|
||||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||||
self._context.model.load_state_dict(checkpoint.state_dict)
|
self._context.model.load_state_dict(checkpoint.state_dict)
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
def with_dataloader(self) -> Self:
|
||||||
# fix: change batch level iteration to sample level offset
|
# fix: change batch level iteration to sample level offset
|
||||||
config = self.config
|
config = self.config
|
||||||
|
|
@ -72,28 +72,28 @@ class TrainContextBuilder:
|
||||||
data_source=config.dataset,
|
data_source=config.dataset,
|
||||||
start_epoch=self._context.epoch,
|
start_epoch=self._context.epoch,
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=config.random_seed
|
seed=config.random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
config.dataset,
|
config.dataset,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
sampler=resumeable_sampler,
|
sampler=resumeable_sampler,
|
||||||
num_workers=config.num_workers,
|
num_workers=config.num_workers,
|
||||||
pin_memory=config.pin_memory,
|
pin_memory=config.pin_memory,
|
||||||
prefetch_factor=config.prefetch_factor
|
prefetch_factor=config.prefetch_factor,
|
||||||
)
|
)
|
||||||
self._context.dataloader = dataloader
|
self._context.dataloader = dataloader
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_strategy(self) -> Self:
|
def with_strategy(self) -> Self:
|
||||||
self._context.strategy = StrategyFactory.create(
|
self._context.strategy = StrategyFactory.create(
|
||||||
model=self._context.model,
|
model=self._context.model,
|
||||||
train_type=self.config.strategy,
|
train_type=self.config.strategy,
|
||||||
device=get_current_device(),
|
device=get_current_device(),
|
||||||
**self.config.extra_kwargs
|
**self.config.extra_kwargs,
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
return self._context
|
return self._context
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@ import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from khaosz.config import TrainConfig
|
from khaosz.config import TrainConfig
|
||||||
from khaosz.trainer.train_callback import (
|
from khaosz.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
MetricLoggerCallback,
|
MetricLoggerCallback,
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
SchedulerCallback
|
SchedulerCallback,
|
||||||
)
|
)
|
||||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
|
|
@ -18,37 +18,39 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
|
||||||
train_config: TrainConfig,
|
|
||||||
callbacks: Optional[List[TrainCallback]] = None
|
|
||||||
):
|
):
|
||||||
self.train_config = train_config
|
self.train_config = train_config
|
||||||
default_callbacks = self._get_default_callbacks()
|
default_callbacks = self._get_default_callbacks()
|
||||||
self.callbacks = default_callbacks + callbacks if callbacks else default_callbacks
|
self.callbacks = (
|
||||||
|
default_callbacks + callbacks if callbacks else default_callbacks
|
||||||
|
)
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
train_config = self.train_config
|
train_config = self.train_config
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(train_config.n_epoch),
|
ProgressBarCallback(train_config.n_epoch),
|
||||||
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||||
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||||
GradientClippingCallback(train_config.max_grad_norm),
|
GradientClippingCallback(train_config.max_grad_norm),
|
||||||
SchedulerCallback(),
|
SchedulerCallback(),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (TrainContextBuilder(self.train_config)
|
return (
|
||||||
.with_checkpoint(checkpoint)
|
TrainContextBuilder(self.train_config)
|
||||||
.with_dataloader()
|
.with_checkpoint(checkpoint)
|
||||||
.with_strategy()
|
.with_dataloader()
|
||||||
.build())
|
.with_strategy()
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
method = getattr(callback, method_name, None)
|
method = getattr(callback, method_name, None)
|
||||||
if method:
|
if method:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
config = self.train_config
|
config = self.train_config
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(
|
||||||
|
|
@ -59,45 +61,45 @@ class Trainer:
|
||||||
master_port=config.master_port,
|
master_port=config.master_port,
|
||||||
device_type=config.device_type,
|
device_type=config.device_type,
|
||||||
device_ids=config.device_ids,
|
device_ids=config.device_ids,
|
||||||
checkpoint=checkpoint
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||||
context = self._build_context(checkpoint)
|
context = self._build_context(checkpoint)
|
||||||
self._call_callbacks('on_train_begin', context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
# 1.epoch
|
# 1.epoch
|
||||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks('on_epoch_begin', context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
# 3. batch
|
# 3. batch
|
||||||
self._call_callbacks('on_batch_begin', context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
loss = context.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
context.iteration += 1
|
context.iteration += 1
|
||||||
|
|
||||||
# to make the loss normalized by accumulation steps
|
# to make the loss normalized by accumulation steps
|
||||||
stand_loss = loss / self.train_config.accumulation_steps
|
stand_loss = loss / self.train_config.accumulation_steps
|
||||||
stand_loss.backward()
|
stand_loss.backward()
|
||||||
|
|
||||||
self._call_callbacks('on_batch_end', context)
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||||
# 2. step
|
# 2. step
|
||||||
self._call_callbacks('on_step_begin', context)
|
self._call_callbacks("on_step_begin", context)
|
||||||
context.optimizer.step()
|
context.optimizer.step()
|
||||||
context.optimizer.zero_grad()
|
context.optimizer.zero_grad()
|
||||||
self._call_callbacks('on_step_end', context)
|
self._call_callbacks("on_step_end", context)
|
||||||
|
|
||||||
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
self._call_callbacks('on_epoch_end', context)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||||
self._call_callbacks('on_error', context)
|
self._call_callbacks("on_error", context)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks('on_train_end', context)
|
self._call_callbacks("on_train_end", context)
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ classifiers = [
|
||||||
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
|
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = ["pytest==9.0.2"]
|
dev = ["pytest==9.0.2", "ruff"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["."]
|
where = ["."]
|
||||||
|
|
@ -35,4 +35,13 @@ where = ["."]
|
||||||
extra-index-url = "https://download.pytorch.org/whl/cu126"
|
extra-index-url = "https://download.pytorch.org/whl/cu126"
|
||||||
|
|
||||||
[tool.setuptools.dynamic]
|
[tool.setuptools.dynamic]
|
||||||
version = { attr = "khaosz.__version__" }
|
version = { attr = "khaosz.__version__" }
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py312"
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
indent-style = "space"
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
|
line-ending = "auto"
|
||||||
|
|
@ -17,14 +17,14 @@ class RandomDataset(Dataset):
|
||||||
self.length = length or int(np.random.randint(100, 200))
|
self.length = length or int(np.random.randint(100, 200))
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
||||||
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,))
|
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -33,10 +33,10 @@ class MultiTurnDataset(Dataset):
|
||||||
self.length = length or int(np.random.randint(100, 200))
|
self.length = length or int(np.random.randint(100, 200))
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
||||||
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
||||||
|
|
@ -54,18 +54,18 @@ class EarlyStoppingDataset(Dataset):
|
||||||
self.length = length
|
self.length = length
|
||||||
self.stop_after = stop_after
|
self.stop_after = stop_after
|
||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
self.count += 1
|
self.count += 1
|
||||||
if self.count == self.stop_after:
|
if self.count == self.stop_after:
|
||||||
raise RuntimeError("Simulated early stopping")
|
raise RuntimeError("Simulated early stopping")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.randint(0, 1000, (64,)),
|
"input_ids": torch.randint(0, 1000, (64,)),
|
||||||
"target_ids": torch.randint(0, 1000, (64,))
|
"target_ids": torch.randint(0, 1000, (64,)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -74,10 +74,10 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
func_name = request.function.__name__
|
func_name = request.function.__name__
|
||||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
n_dim_choices = [8, 16, 32]
|
n_dim_choices = [8, 16, 32]
|
||||||
n_head_choices = [2, 4]
|
n_head_choices = [2, 4]
|
||||||
|
|
||||||
dim = int(np.random.choice(n_dim_choices))
|
dim = int(np.random.choice(n_dim_choices))
|
||||||
n_heads = int(np.random.choice(n_head_choices))
|
n_heads = int(np.random.choice(n_head_choices))
|
||||||
n_kv_heads = n_heads // 2
|
n_kv_heads = n_heads // 2
|
||||||
|
|
@ -91,16 +91,16 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
"dim_ffn": dim_ffn,
|
"dim_ffn": dim_ffn,
|
||||||
"max_len": 1024,
|
"max_len": 1024,
|
||||||
"n_layers": 4,
|
"n_layers": 4,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
transformer_config = ModelConfig().load(config_path)
|
transformer_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(transformer_config).to(device=device)
|
model = Transformer(transformer_config).to(device=device)
|
||||||
tokenizer = BpeTokenizer()
|
tokenizer = BpeTokenizer()
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"device": device,
|
"device": device,
|
||||||
"test_dir": str(test_dir),
|
"test_dir": str(test_dir),
|
||||||
|
|
@ -109,20 +109,23 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
"model": model,
|
"model": model,
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
shutil.rmtree(test_dir)
|
shutil.rmtree(test_dir)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def random_dataset():
|
def random_dataset():
|
||||||
dataset = RandomDataset()
|
dataset = RandomDataset()
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def multi_turn_dataset():
|
def multi_turn_dataset():
|
||||||
dataset = MultiTurnDataset()
|
dataset = MultiTurnDataset()
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def early_stopping_dataset():
|
def early_stopping_dataset():
|
||||||
dataset = EarlyStoppingDataset()
|
dataset = EarlyStoppingDataset()
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
||||||
|
|
||||||
|
|
||||||
def test_single_process():
|
def test_single_process():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
|
|
@ -14,34 +15,31 @@ def test_single_process():
|
||||||
|
|
||||||
for epoch in range(3):
|
for epoch in range(3):
|
||||||
for iteration in range(10):
|
for iteration in range(10):
|
||||||
|
|
||||||
x = torch.randn(32, 10)
|
x = torch.randn(32, 10)
|
||||||
y = torch.randn(32, 5)
|
y = torch.randn(32, 5)
|
||||||
loss = model(x).mean()
|
loss = model(x).mean()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30)
|
||||||
state_dict=model.state_dict(),
|
|
||||||
epoch=3,
|
|
||||||
iteration=30
|
|
||||||
)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
checkpoint.save(tmpdir)
|
checkpoint.save(tmpdir)
|
||||||
|
|
||||||
loaded_checkpoint = Checkpoint.load(tmpdir)
|
loaded_checkpoint = Checkpoint.load(tmpdir)
|
||||||
|
|
||||||
assert loaded_checkpoint.epoch == 3
|
assert loaded_checkpoint.epoch == 3
|
||||||
assert loaded_checkpoint.iteration == 30
|
assert loaded_checkpoint.iteration == 30
|
||||||
|
|
||||||
|
|
||||||
def simple_training():
|
def simple_training():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=10)
|
scheduler = CosineAnnealingLR(optimizer, T_max=10)
|
||||||
|
|
||||||
for epoch in range(2):
|
for epoch in range(2):
|
||||||
for iteration in range(5):
|
for iteration in range(5):
|
||||||
x = torch.randn(16, 10)
|
x = torch.randn(16, 10)
|
||||||
|
|
@ -57,28 +55,23 @@ def simple_training():
|
||||||
epoch=2,
|
epoch=2,
|
||||||
iteration=10,
|
iteration=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
rank = get_rank()
|
rank = get_rank()
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
shared_dir = tempfile.mkdtemp()
|
shared_dir = tempfile.mkdtemp()
|
||||||
checkpoint.save(shared_dir)
|
checkpoint.save(shared_dir)
|
||||||
else:
|
else:
|
||||||
shared_dir = None
|
shared_dir = None
|
||||||
|
|
||||||
|
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
dir_list = [shared_dir]
|
dir_list = [shared_dir]
|
||||||
dist.broadcast_object_list(dir_list, src=0)
|
dist.broadcast_object_list(dir_list, src=0)
|
||||||
shared_dir = dir_list[0]
|
shared_dir = dir_list[0]
|
||||||
|
|
||||||
|
|
||||||
loaded = Checkpoint.load(shared_dir)
|
loaded = Checkpoint.load(shared_dir)
|
||||||
assert loaded.epoch == 2
|
assert loaded.epoch == 2
|
||||||
|
|
||||||
|
|
||||||
def test_multi_process():
|
def test_multi_process():
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
|
||||||
simple_training,
|
|
||||||
world_size=2,
|
|
||||||
backend="gloo"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -5,30 +5,32 @@ from khaosz.data.serialization import save_h5
|
||||||
from khaosz.data.dataset import *
|
from khaosz.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
"""Test dataset loader with multiple random paths"""
|
"""Test dataset loader with multiple random paths"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
# Create multiple mmap dataset directories with random data
|
# Create multiple mmap dataset directories with random data
|
||||||
num_files = np.random.randint(2, 5)
|
num_files = np.random.randint(2, 5)
|
||||||
|
|
||||||
for i in range(num_files):
|
for i in range(num_files):
|
||||||
seq_length = np.random.randint(200, 400)
|
seq_length = np.random.randint(200, 400)
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
|
"sequence": [
|
||||||
|
torch.randint(0, 1000, (seq_length,), dtype=torch.int64)
|
||||||
|
for _ in range(10)
|
||||||
|
],
|
||||||
}
|
}
|
||||||
save_h5(test_dir, f"data_{i}", dummy_data)
|
save_h5(test_dir, f"data_{i}", dummy_data)
|
||||||
|
|
||||||
# Test loading with multiple paths
|
# Test loading with multiple paths
|
||||||
loaded_dataset = DatasetLoader.load(
|
loaded_dataset = DatasetLoader.load(
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
assert loaded_dataset is not None
|
assert loaded_dataset is not None
|
||||||
assert len(loaded_dataset) > 0
|
assert len(loaded_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get items without errors
|
# Test that we can get items without errors
|
||||||
for i in range(len(loaded_dataset)):
|
for i in range(len(loaded_dataset)):
|
||||||
item = loaded_dataset[i]
|
item = loaded_dataset[i]
|
||||||
|
|
@ -41,30 +43,30 @@ def test_dataset_loader_random_paths(base_test_env):
|
||||||
def test_dpo_strategy_with_random_data(base_test_env):
|
def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
"""Test DPO strategy with randomized preference data"""
|
"""Test DPO strategy with randomized preference data"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
# Create DPO-style data with memory mapping format
|
# Create DPO-style data with memory mapping format
|
||||||
seq_length = np.random.randint(100, 200)
|
seq_length = np.random.randint(100, 200)
|
||||||
|
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir, "dpo_data", dummy_data)
|
save_h5(test_dir, "dpo_data", dummy_data)
|
||||||
|
|
||||||
# Load DPO dataset
|
# Load DPO dataset
|
||||||
dpo_dataset = DatasetLoader.load(
|
dpo_dataset = DatasetLoader.load(
|
||||||
train_type="dpo",
|
train_type="dpo",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dpo_dataset is not None
|
assert dpo_dataset is not None
|
||||||
assert hasattr(dpo_dataset, 'fetcher')
|
assert hasattr(dpo_dataset, "fetcher")
|
||||||
assert len(dpo_dataset) > 0
|
assert len(dpo_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get DPO items without errors
|
# Test that we can get DPO items without errors
|
||||||
for i in range(min(3, len(dpo_dataset))):
|
for i in range(min(3, len(dpo_dataset))):
|
||||||
item = dpo_dataset[i]
|
item = dpo_dataset[i]
|
||||||
|
|
@ -79,28 +81,28 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
def test_sft_dataset_with_random_data(base_test_env):
|
def test_sft_dataset_with_random_data(base_test_env):
|
||||||
"""Test SFT dataset with random data"""
|
"""Test SFT dataset with random data"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
# Create SFT-style data with memory mapping format
|
# Create SFT-style data with memory mapping format
|
||||||
seq_length = np.random.randint(100, 200)
|
seq_length = np.random.randint(100, 200)
|
||||||
|
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir, "sft_data", dummy_data)
|
save_h5(test_dir, "sft_data", dummy_data)
|
||||||
|
|
||||||
# Load SFT dataset
|
# Load SFT dataset
|
||||||
sft_dataset = DatasetLoader.load(
|
sft_dataset = DatasetLoader.load(
|
||||||
train_type="sft",
|
train_type="sft",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert sft_dataset is not None
|
assert sft_dataset is not None
|
||||||
assert hasattr(sft_dataset, 'fetcher')
|
assert hasattr(sft_dataset, "fetcher")
|
||||||
assert len(sft_dataset) > 0
|
assert len(sft_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get SFT items without errors
|
# Test that we can get SFT items without errors
|
||||||
for i in range(min(3, len(sft_dataset))):
|
for i in range(min(3, len(sft_dataset))):
|
||||||
item = sft_dataset[i]
|
item = sft_dataset[i]
|
||||||
|
|
@ -114,33 +116,30 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
def test_dataset_with_custom_stride(base_test_env):
|
def test_dataset_with_custom_stride(base_test_env):
|
||||||
"""Test dataset with custom stride parameter"""
|
"""Test dataset with custom stride parameter"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
# Create test data
|
# Create test data
|
||||||
seq_length = 200
|
seq_length = 200
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir,"stride_test_data", dummy_data)
|
save_h5(test_dir, "stride_test_data", dummy_data)
|
||||||
|
|
||||||
# Test with custom stride
|
# Test with custom stride
|
||||||
custom_stride = 32
|
custom_stride = 32
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type="seq",
|
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
|
||||||
load_path=test_dir,
|
|
||||||
window_size=64,
|
|
||||||
stride=custom_stride
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dataset is not None
|
assert dataset is not None
|
||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
|
|
||||||
# With stride 32 and window 64 on 200 length data, we should get more samples
|
# With stride 32 and window 64 on 200 length data, we should get more samples
|
||||||
# than with default stride (which equals window size)
|
# than with default stride (which equals window size)
|
||||||
default_stride_dataset = DatasetLoader.load(
|
default_stride_dataset = DatasetLoader.load(
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=test_dir,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(dataset) > len(default_stride_dataset)
|
assert len(dataset) > len(default_stride_dataset)
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,32 @@
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.data import *
|
from khaosz.data import *
|
||||||
|
|
||||||
|
|
||||||
def test_random_sampler_consistency(random_dataset):
|
def test_random_sampler_consistency(random_dataset):
|
||||||
"""Test RandomSampler produces consistent results with same seed"""
|
"""Test RandomSampler produces consistent results with same seed"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
||||||
# Create two samplers with same seed
|
# Create two samplers with same seed
|
||||||
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||||
sampler2 = ResumableDistributedSampler(dataset, seed=42)
|
sampler2 = ResumableDistributedSampler(dataset, seed=42)
|
||||||
|
|
||||||
indices1 = list(iter(sampler1))
|
indices1 = list(iter(sampler1))
|
||||||
indices2 = list(iter(sampler2))
|
indices2 = list(iter(sampler2))
|
||||||
|
|
||||||
assert indices1 == indices2
|
assert indices1 == indices2
|
||||||
|
|
||||||
|
|
||||||
def test_random_sampler_different_seeds(random_dataset):
|
def test_random_sampler_different_seeds(random_dataset):
|
||||||
"""Test RandomSampler produces different results with different seeds"""
|
"""Test RandomSampler produces different results with different seeds"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
||||||
# Create two samplers with different seeds
|
# Create two samplers with different seeds
|
||||||
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||||
sampler2 = ResumableDistributedSampler(dataset, seed=123)
|
sampler2 = ResumableDistributedSampler(dataset, seed=123)
|
||||||
|
|
||||||
indices1 = list(iter(sampler1))
|
indices1 = list(iter(sampler1))
|
||||||
indices2 = list(iter(sampler2))
|
indices2 = list(iter(sampler2))
|
||||||
|
|
||||||
# Very high probability they should be different
|
# Very high probability they should be different
|
||||||
assert indices1 != indices2
|
assert indices1 != indices2
|
||||||
|
|
||||||
|
|
@ -33,20 +35,20 @@ def test_sampler_across_epochs(random_dataset):
|
||||||
"""Test sampler behavior across multiple epochs"""
|
"""Test sampler behavior across multiple epochs"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
n = len(dataset)
|
n = len(dataset)
|
||||||
|
|
||||||
sampler = ResumableDistributedSampler(dataset, seed=42)
|
sampler = ResumableDistributedSampler(dataset, seed=42)
|
||||||
|
|
||||||
# Get indices for first epoch
|
# Get indices for first epoch
|
||||||
epoch1_indices = list(iter(sampler))
|
epoch1_indices = list(iter(sampler))
|
||||||
assert len(epoch1_indices) == n
|
assert len(epoch1_indices) == n
|
||||||
|
|
||||||
# Get indices for second epoch
|
# Get indices for second epoch
|
||||||
epoch2_indices = list(iter(sampler))
|
epoch2_indices = list(iter(sampler))
|
||||||
assert len(epoch2_indices) == n
|
assert len(epoch2_indices) == n
|
||||||
|
|
||||||
# Check that epochs have different order (should be random)
|
# Check that epochs have different order (should be random)
|
||||||
assert epoch1_indices != epoch2_indices
|
assert epoch1_indices != epoch2_indices
|
||||||
|
|
||||||
# Check that all indices are present in each epoch
|
# Check that all indices are present in each epoch
|
||||||
assert set(epoch1_indices) == set(range(n))
|
assert set(epoch1_indices) == set(range(n))
|
||||||
assert set(epoch2_indices) == set(range(n))
|
assert set(epoch2_indices) == set(range(n))
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from khaosz.data import *
|
||||||
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
||||||
from tokenizers import pre_tokenizers
|
from tokenizers import pre_tokenizers
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_env(request: pytest.FixtureRequest):
|
def test_env(request: pytest.FixtureRequest):
|
||||||
func_name = request.function.__name__
|
func_name = request.function.__name__
|
||||||
|
|
@ -19,7 +20,7 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
||||||
model_path = os.path.join(test_dir, "model.safetensors")
|
model_path = os.path.join(test_dir, "model.safetensors")
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 128,
|
"dim": 128,
|
||||||
|
|
@ -28,20 +29,20 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
"dim_ffn": 256,
|
"dim_ffn": 256,
|
||||||
"max_len": 64,
|
"max_len": 64,
|
||||||
"n_layers": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5,
|
||||||
}
|
}
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
|
|
||||||
tokenizer = BpeTokenizer()
|
tokenizer = BpeTokenizer()
|
||||||
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
|
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
|
||||||
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
||||||
tokenizer.save(tokenizer_path)
|
tokenizer.save(tokenizer_path)
|
||||||
|
|
||||||
transformer_config = ModelConfig().load(config_path)
|
transformer_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(transformer_config)
|
model = Transformer(transformer_config)
|
||||||
st.save_file(model.state_dict(), model_path)
|
st.save_file(model.state_dict(), model_path)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"test_dir": test_dir,
|
"test_dir": test_dir,
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
@ -51,47 +52,55 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
|
|
||||||
shutil.rmtree(test_dir)
|
shutil.rmtree(test_dir)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parameter(test_env):
|
def test_model_parameter(test_env):
|
||||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||||
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
|
model_param = ModelParameter(
|
||||||
|
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||||
|
)
|
||||||
ModelParameter.save(model_param, save_dir)
|
ModelParameter.save(model_param, save_dir)
|
||||||
|
|
||||||
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
|
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
|
||||||
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
||||||
assert os.path.exists(os.path.join(save_dir, "config.json"))
|
assert os.path.exists(os.path.join(save_dir, "config.json"))
|
||||||
|
|
||||||
|
|
||||||
# transformer
|
# transformer
|
||||||
def test_transformer(test_env):
|
def test_transformer(test_env):
|
||||||
model = test_env["model"]
|
model = test_env["model"]
|
||||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
|
input_ids = torch.randint(
|
||||||
(4, test_env["transformer_config"].max_len))
|
0,
|
||||||
|
test_env["transformer_config"].vocab_size,
|
||||||
|
(4, test_env["transformer_config"].max_len),
|
||||||
|
)
|
||||||
output_logits = model(input_ids)["logits"]
|
output_logits = model(input_ids)["logits"]
|
||||||
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
|
target_shape = (
|
||||||
|
4,
|
||||||
|
test_env["transformer_config"].max_len,
|
||||||
|
test_env["transformer_config"].vocab_size,
|
||||||
|
)
|
||||||
assert output_logits.shape == target_shape
|
assert output_logits.shape == target_shape
|
||||||
|
|
||||||
|
|
||||||
# generator
|
# generator
|
||||||
def test_embedding_encoder_core(test_env):
|
def test_embedding_encoder_core(test_env):
|
||||||
parameter = ModelParameter(
|
parameter = ModelParameter(
|
||||||
test_env["model"],
|
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||||
test_env["tokenizer"],
|
|
||||||
test_env["transformer_config"]
|
|
||||||
)
|
)
|
||||||
encoder = EmbeddingEncoderCore(parameter)
|
encoder = EmbeddingEncoderCore(parameter)
|
||||||
|
|
||||||
single_emb = encoder.encode("测试文本")
|
single_emb = encoder.encode("测试文本")
|
||||||
assert isinstance(single_emb, torch.Tensor)
|
assert isinstance(single_emb, torch.Tensor)
|
||||||
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
||||||
|
|
||||||
|
|
||||||
batch_emb = encoder.encode(["测试1", "测试2"])
|
batch_emb = encoder.encode(["测试1", "测试2"])
|
||||||
assert isinstance(batch_emb, list)
|
assert isinstance(batch_emb, list)
|
||||||
assert len(batch_emb) == 2
|
assert len(batch_emb) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_generator_core(test_env):
|
def test_generator_core(test_env):
|
||||||
parameter = ModelParameter(
|
parameter = ModelParameter(
|
||||||
test_env["model"],
|
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||||
test_env["tokenizer"],
|
|
||||||
test_env["transformer_config"]
|
|
||||||
)
|
)
|
||||||
generator = GeneratorCore(parameter)
|
generator = GeneratorCore(parameter)
|
||||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
||||||
|
|
@ -102,8 +111,8 @@ def test_generator_core(test_env):
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
attn_mask=None,
|
attn_mask=None,
|
||||||
kv_caches=None,
|
kv_caches=None,
|
||||||
start_pos=0
|
start_pos=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert next_token_id.shape == (4, 1)
|
assert next_token_id.shape == (4, 1)
|
||||||
assert cache_increase == 10
|
assert cache_increase == 10
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ def transformer_test_env():
|
||||||
"""创建Transformer测试专用环境"""
|
"""创建Transformer测试专用环境"""
|
||||||
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
|
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 128,
|
"dim": 128,
|
||||||
|
|
@ -22,18 +22,14 @@ def transformer_test_env():
|
||||||
"dim_ffn": 256,
|
"dim_ffn": 256,
|
||||||
"max_len": 64,
|
"max_len": 64,
|
||||||
"n_layers": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
|
|
||||||
yield {
|
yield {"test_dir": test_dir, "config_path": config_path, "config": config}
|
||||||
"test_dir": test_dir,
|
|
||||||
"config_path": config_path,
|
|
||||||
"config": config
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.path.exists(test_dir):
|
if os.path.exists(test_dir):
|
||||||
try:
|
try:
|
||||||
for file in os.listdir(test_dir):
|
for file in os.listdir(test_dir):
|
||||||
|
|
@ -46,74 +42,75 @@ def transformer_test_env():
|
||||||
def test_tie_weight_init(transformer_test_env):
|
def test_tie_weight_init(transformer_test_env):
|
||||||
config_path = transformer_test_env["config_path"]
|
config_path = transformer_test_env["config_path"]
|
||||||
config_data = transformer_test_env["config"].copy()
|
config_data = transformer_test_env["config"].copy()
|
||||||
|
|
||||||
# case 1: tie weight
|
# case 1: tie weight
|
||||||
config_data["tie_weight"] = True
|
config_data["tie_weight"] = True
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
model = Transformer(config)
|
model = Transformer(config)
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||||
|
|
||||||
original_weight = model.embed_tokens.weight.clone()
|
original_weight = model.embed_tokens.weight.clone()
|
||||||
model.embed_tokens.weight.data[0, 0] = 100.0
|
model.embed_tokens.weight.data[0, 0] = 100.0
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert not torch.equal(model.lm_head.weight, original_weight)
|
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||||
|
|
||||||
# case 2: not tie weight
|
# case 2: not tie weight
|
||||||
config_data["tie_weight"] = False
|
config_data["tie_weight"] = False
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
model = Transformer(config)
|
model = Transformer(config)
|
||||||
|
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||||
|
|
||||||
original_weight = model.embed_tokens.weight.clone()
|
original_weight = model.embed_tokens.weight.clone()
|
||||||
model.embed_tokens.weight.data[0, 0] = 100.0
|
model.embed_tokens.weight.data[0, 0] = 100.0
|
||||||
|
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert not torch.equal(model.lm_head.weight, original_weight)
|
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||||
|
|
||||||
|
|
||||||
def test_model_save_load_with_tie_weight(transformer_test_env):
|
def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
test_dir = transformer_test_env["test_dir"]
|
test_dir = transformer_test_env["test_dir"]
|
||||||
model_path = os.path.join(test_dir, "model.safetensors")
|
model_path = os.path.join(test_dir, "model.safetensors")
|
||||||
|
|
||||||
config_data = transformer_test_env["config"].copy()
|
config_data = transformer_test_env["config"].copy()
|
||||||
|
|
||||||
# case 1: tie weight
|
# case 1: tie weight
|
||||||
config_data["tie_weight"] = True
|
config_data["tie_weight"] = True
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
original_model = Transformer(config)
|
original_model = Transformer(config)
|
||||||
|
|
||||||
st.save_file(original_model.state_dict(), model_path)
|
st.save_file(original_model.state_dict(), model_path)
|
||||||
|
|
||||||
loaded_config = ModelConfig().load(config_path)
|
loaded_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = Transformer(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||||
assert "lm_head.weight" not in model.state_dict()
|
assert "lm_head.weight" not in model.state_dict()
|
||||||
|
|
||||||
# case 2: not tie weight (form tie-weight state dict load)
|
# case 2: not tie weight (form tie-weight state dict load)
|
||||||
config_data["tie_weight"] = False
|
config_data["tie_weight"] = False
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
loaded_config = ModelConfig().load(config_path)
|
loaded_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = Transformer(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
@ -121,4 +118,3 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||||
assert "lm_head.weight" in model.state_dict()
|
assert "lm_head.weight" in model.state_dict()
|
||||||
|
|
||||||
|
|
@ -1,16 +1,14 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from khaosz.parallel import (
|
from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
|
||||||
get_rank,
|
|
||||||
only_on_rank,
|
|
||||||
spawn_parallel_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _test_only_on_rank_helper():
|
def _test_only_on_rank_helper():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def only_on_rank():
|
def only_on_rank():
|
||||||
result = _test_only_on_rank_helper()
|
result = _test_only_on_rank_helper()
|
||||||
if get_rank() == 0:
|
if get_rank() == 0:
|
||||||
|
|
@ -18,22 +16,17 @@ def only_on_rank():
|
||||||
else:
|
else:
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
def all_reduce():
|
def all_reduce():
|
||||||
x = torch.tensor([get_rank()], dtype=torch.int)
|
x = torch.tensor([get_rank()], dtype=torch.int)
|
||||||
dist.all_reduce(x, op=dist.ReduceOp.SUM)
|
dist.all_reduce(x, op=dist.ReduceOp.SUM)
|
||||||
expected_sum = sum(range(dist.get_world_size()))
|
expected_sum = sum(range(dist.get_world_size()))
|
||||||
assert x.item() == expected_sum
|
assert x.item() == expected_sum
|
||||||
|
|
||||||
|
|
||||||
def test_spawn_only_on_rank():
|
def test_spawn_only_on_rank():
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo")
|
||||||
only_on_rank,
|
|
||||||
world_size=2,
|
|
||||||
backend="gloo"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_spawn_all_reduce():
|
def test_spawn_all_reduce():
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(all_reduce, world_size=2, backend="gloo")
|
||||||
all_reduce,
|
|
||||||
world_size=2,
|
|
||||||
backend="gloo"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -3,57 +3,48 @@ import torch
|
||||||
from khaosz.config import *
|
from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
strategy='seq',
|
strategy="seq",
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=3,
|
ckpt_interval=3,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Create custom callbacks to track calls
|
# Create custom callbacks to track calls
|
||||||
callback_calls = []
|
callback_calls = []
|
||||||
|
|
||||||
class TrackingCallback(TrainCallback):
|
class TrackingCallback(TrainCallback):
|
||||||
def on_train_begin(self, context):
|
def on_train_begin(self, context):
|
||||||
callback_calls.append('on_train_begin')
|
callback_calls.append("on_train_begin")
|
||||||
|
|
||||||
def on_batch_end(self, context):
|
def on_batch_end(self, context):
|
||||||
callback_calls.append('on_batch_end')
|
callback_calls.append("on_batch_end")
|
||||||
|
|
||||||
def on_epoch_end(self, context):
|
def on_epoch_end(self, context):
|
||||||
callback_calls.append('on_epoch_end')
|
callback_calls.append("on_epoch_end")
|
||||||
|
|
||||||
|
trainer = Trainer(train_config, callbacks=[TrackingCallback()])
|
||||||
|
|
||||||
|
|
||||||
trainer = Trainer(
|
|
||||||
train_config,
|
|
||||||
callbacks=[TrackingCallback()]
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
# Verify callbacks were called
|
# Verify callbacks were called
|
||||||
assert 'on_train_begin' in callback_calls
|
assert "on_train_begin" in callback_calls
|
||||||
assert 'on_batch_end' in callback_calls
|
assert "on_batch_end" in callback_calls
|
||||||
assert 'on_epoch_end' in callback_calls
|
assert "on_epoch_end" in callback_calls
|
||||||
|
|
|
||||||
|
|
@ -5,31 +5,32 @@ from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
|
|
||||||
|
|
||||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
"""Simulate early stopping behavior"""
|
"""Simulate early stopping behavior"""
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=1,
|
ckpt_interval=1,
|
||||||
accumulation_steps=2,
|
accumulation_steps=2,
|
||||||
random_seed=np.random.randint(1e4),
|
random_seed=np.random.randint(1e4),
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
||||||
# Should handle early stopping gracefully
|
# Should handle early stopping gracefully
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
try:
|
try:
|
||||||
|
|
@ -37,11 +38,11 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
except Exception:
|
except Exception:
|
||||||
# Handle any exceptions
|
# Handle any exceptions
|
||||||
pass
|
pass
|
||||||
|
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||||
checkpoint = Checkpoint.load(load_dir)
|
checkpoint = Checkpoint.load(load_dir)
|
||||||
trainer.train(checkpoint)
|
trainer.train(checkpoint)
|
||||||
|
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||||
checkpoint = Checkpoint.load(load_dir)
|
checkpoint = Checkpoint.load(load_dir)
|
||||||
assert checkpoint.iteration == 10
|
assert checkpoint.iteration == 10
|
||||||
|
|
|
||||||
|
|
@ -9,39 +9,41 @@ from khaosz.data.dataset import *
|
||||||
|
|
||||||
def test_schedule_factory_random_configs():
|
def test_schedule_factory_random_configs():
|
||||||
"""Test scheduler factory with random configurations"""
|
"""Test scheduler factory with random configurations"""
|
||||||
|
|
||||||
# Create a simple model and optimizer for testing
|
# Create a simple model and optimizer for testing
|
||||||
model = torch.nn.Linear(10, 2)
|
model = torch.nn.Linear(10, 2)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
# Test multiple random configurations
|
# Test multiple random configurations
|
||||||
for _ in range(5): # Test 5 random configurations
|
for _ in range(5): # Test 5 random configurations
|
||||||
schedule_configs = [
|
schedule_configs = [
|
||||||
CosineScheduleConfig(
|
CosineScheduleConfig(
|
||||||
warmup_steps=np.random.randint(50, 200),
|
warmup_steps=np.random.randint(50, 200),
|
||||||
total_steps=np.random.randint(1000, 5000),
|
total_steps=np.random.randint(1000, 5000),
|
||||||
min_rate=np.random.uniform(0.01, 0.1)
|
min_rate=np.random.uniform(0.01, 0.1),
|
||||||
),
|
),
|
||||||
SGDRScheduleConfig(
|
SGDRScheduleConfig(
|
||||||
warmup_steps=np.random.randint(50, 200),
|
warmup_steps=np.random.randint(50, 200),
|
||||||
cycle_length=np.random.randint(500, 2000),
|
cycle_length=np.random.randint(500, 2000),
|
||||||
t_mult=np.random.randint(1, 3),
|
t_mult=np.random.randint(1, 3),
|
||||||
min_rate=np.random.uniform(0.01, 0.1)
|
min_rate=np.random.uniform(0.01, 0.1),
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in schedule_configs:
|
for config in schedule_configs:
|
||||||
# Validate configuration
|
# Validate configuration
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
# Create scheduler using factory
|
# Create scheduler using factory
|
||||||
scheduler = SchedulerFactory.load(optimizer, config)
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
|
|
||||||
# Verify scheduler type
|
# Verify scheduler type
|
||||||
if isinstance(config, CosineScheduleConfig):
|
if isinstance(config, CosineScheduleConfig):
|
||||||
assert isinstance(scheduler, CosineScheduler)
|
assert isinstance(scheduler, CosineScheduler)
|
||||||
assert scheduler.warmup_steps == config.warmup_steps
|
assert scheduler.warmup_steps == config.warmup_steps
|
||||||
assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
assert (
|
||||||
|
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
||||||
|
)
|
||||||
assert scheduler.min_rate == config.min_rate
|
assert scheduler.min_rate == config.min_rate
|
||||||
elif isinstance(config, SGDRScheduleConfig):
|
elif isinstance(config, SGDRScheduleConfig):
|
||||||
assert isinstance(scheduler, SGDRScheduler)
|
assert isinstance(scheduler, SGDRScheduler)
|
||||||
|
|
@ -49,17 +51,17 @@ def test_schedule_factory_random_configs():
|
||||||
assert scheduler.cycle_length == config.cycle_length
|
assert scheduler.cycle_length == config.cycle_length
|
||||||
assert scheduler.t_mult == config.t_mult
|
assert scheduler.t_mult == config.t_mult
|
||||||
assert scheduler.min_rate == config.min_rate
|
assert scheduler.min_rate == config.min_rate
|
||||||
|
|
||||||
# Test scheduler state dict functionality
|
# Test scheduler state dict functionality
|
||||||
state_dict = scheduler.state_dict()
|
state_dict = scheduler.state_dict()
|
||||||
assert 'warmup_steps' in state_dict
|
assert "warmup_steps" in state_dict
|
||||||
assert 'min_rate' in state_dict
|
assert "min_rate" in state_dict
|
||||||
|
|
||||||
# Test scheduler step functionality
|
# Test scheduler step functionality
|
||||||
initial_lr = scheduler.get_last_lr()
|
initial_lr = scheduler.get_last_lr()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
new_lr = scheduler.get_last_lr()
|
new_lr = scheduler.get_last_lr()
|
||||||
|
|
||||||
# Learning rate should change after step, or if it's the first step,
|
# Learning rate should change after step, or if it's the first step,
|
||||||
# the epoch counter should increment
|
# the epoch counter should increment
|
||||||
assert initial_lr != new_lr or scheduler.last_epoch > -1
|
assert initial_lr != new_lr or scheduler.last_epoch > -1
|
||||||
|
|
@ -67,10 +69,10 @@ def test_schedule_factory_random_configs():
|
||||||
|
|
||||||
def test_schedule_factory_edge_cases():
|
def test_schedule_factory_edge_cases():
|
||||||
"""Test scheduler factory with edge cases and boundary conditions"""
|
"""Test scheduler factory with edge cases and boundary conditions"""
|
||||||
|
|
||||||
model = torch.nn.Linear(10, 2)
|
model = torch.nn.Linear(10, 2)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
# Test edge cases for CosineScheduleConfig
|
# Test edge cases for CosineScheduleConfig
|
||||||
edge_cases = [
|
edge_cases = [
|
||||||
# Minimal warmup and steps
|
# Minimal warmup and steps
|
||||||
|
|
@ -80,12 +82,12 @@ def test_schedule_factory_edge_cases():
|
||||||
# Zero min_rate (edge case)
|
# Zero min_rate (edge case)
|
||||||
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
|
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in edge_cases:
|
for config in edge_cases:
|
||||||
config.validate()
|
config.validate()
|
||||||
scheduler = SchedulerFactory.load(optimizer, config)
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
assert scheduler is not None
|
assert scheduler is not None
|
||||||
|
|
||||||
# Test multiple steps
|
# Test multiple steps
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
@ -93,7 +95,7 @@ def test_schedule_factory_edge_cases():
|
||||||
|
|
||||||
def test_schedule_factory_invalid_configs():
|
def test_schedule_factory_invalid_configs():
|
||||||
"""Test scheduler factory with invalid configurations"""
|
"""Test scheduler factory with invalid configurations"""
|
||||||
|
|
||||||
# Test invalid configurations that should raise errors
|
# Test invalid configurations that should raise errors
|
||||||
invalid_configs = [
|
invalid_configs = [
|
||||||
# Negative warmup steps
|
# Negative warmup steps
|
||||||
|
|
@ -104,7 +106,7 @@ def test_schedule_factory_invalid_configs():
|
||||||
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
|
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
|
||||||
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
|
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
|
||||||
]
|
]
|
||||||
|
|
||||||
for kwargs in invalid_configs:
|
for kwargs in invalid_configs:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
config = CosineScheduleConfig(**kwargs)
|
config = CosineScheduleConfig(**kwargs)
|
||||||
|
|
@ -113,24 +115,24 @@ def test_schedule_factory_invalid_configs():
|
||||||
|
|
||||||
def test_schedule_factory_state_persistence():
|
def test_schedule_factory_state_persistence():
|
||||||
"""Test scheduler state persistence (save/load)"""
|
"""Test scheduler state persistence (save/load)"""
|
||||||
|
|
||||||
model = torch.nn.Linear(10, 2)
|
model = torch.nn.Linear(10, 2)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
|
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
|
||||||
scheduler = SchedulerFactory.load(optimizer, config)
|
scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
|
|
||||||
# Take a few steps
|
# Take a few steps
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# Save state
|
# Save state
|
||||||
state_dict = scheduler.state_dict()
|
state_dict = scheduler.state_dict()
|
||||||
|
|
||||||
# Create new scheduler and load state
|
# Create new scheduler and load state
|
||||||
new_scheduler = SchedulerFactory.load(optimizer, config)
|
new_scheduler = SchedulerFactory.load(optimizer, config)
|
||||||
new_scheduler.load_state_dict(state_dict)
|
new_scheduler.load_state_dict(state_dict)
|
||||||
|
|
||||||
# Verify states match
|
# Verify states match
|
||||||
assert scheduler.last_epoch == new_scheduler.last_epoch
|
assert scheduler.last_epoch == new_scheduler.last_epoch
|
||||||
assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
|
assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
|
||||||
|
|
|
||||||
|
|
@ -6,100 +6,94 @@ from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.data.dataset import *
|
from khaosz.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
def test_different_batch_sizes(base_test_env, random_dataset):
|
def test_different_batch_sizes(base_test_env, random_dataset):
|
||||||
"""Test training with different batch sizes"""
|
"""Test training with different batch sizes"""
|
||||||
batch_sizes = [1, 2, 4, 8]
|
batch_sizes = [1, 2, 4, 8]
|
||||||
|
|
||||||
for batch_size in batch_sizes:
|
for batch_size in batch_sizes:
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
checkpoint_interval=5,
|
ckpt_interval=5,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=np.random.randint(1000),
|
random_seed=np.random.randint(1000),
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.batch_size == batch_size
|
assert train_config.batch_size == batch_size
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_accumulation(base_test_env, random_dataset):
|
def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
"""Test training with different gradient accumulation steps"""
|
"""Test training with different gradient accumulation steps"""
|
||||||
accumulation_steps_list = [1, 2, 4]
|
accumulation_steps_list = [1, 2, 4]
|
||||||
|
|
||||||
for accumulation_steps in accumulation_steps_list:
|
for accumulation_steps in accumulation_steps_list:
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=10,
|
ckpt_interval=10,
|
||||||
accumulation_steps=accumulation_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
assert train_config.accumulation_steps == accumulation_steps
|
assert train_config.accumulation_steps == accumulation_steps
|
||||||
|
|
||||||
|
|
||||||
def test_memory_efficient_training(base_test_env, random_dataset):
|
def test_memory_efficient_training(base_test_env, random_dataset):
|
||||||
"""Test training with memory-efficient configurations"""
|
"""Test training with memory-efficient configurations"""
|
||||||
# Test with smaller batch sizes and gradient checkpointing
|
# Test with smaller batch sizes and gradient checkpointing
|
||||||
small_batch_configs = [
|
small_batch_configs = [
|
||||||
{"batch_size": 1, "accumulation_steps": 8},
|
{"batch_size": 1, "accumulation_steps": 8},
|
||||||
{"batch_size": 2, "accumulation_steps": 4},
|
{"batch_size": 2, "accumulation_steps": 4},
|
||||||
{"batch_size": 4, "accumulation_steps": 2}
|
{"batch_size": 4, "accumulation_steps": 2},
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=config["batch_size"],
|
batch_size=config["batch_size"],
|
||||||
checkpoint_interval=5,
|
ckpt_interval=5,
|
||||||
accumulation_steps=config["accumulation_steps"],
|
accumulation_steps=config["accumulation_steps"],
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||||
|
|
|
||||||
|
|
@ -17,41 +17,47 @@ class GenerationBenchmark:
|
||||||
self,
|
self,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.float16
|
dtype: torch.dtype = torch.float16,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||||
"""初始化KV缓存"""
|
"""初始化KV缓存"""
|
||||||
config = self.config
|
config = self.config
|
||||||
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
|
shape = (
|
||||||
|
batch_size,
|
||||||
|
config.max_len,
|
||||||
|
config.n_layers,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.dim // config.n_heads,
|
||||||
|
)
|
||||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
return (k_cache, v_cache)
|
return (k_cache, v_cache)
|
||||||
|
|
||||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||||
prompt_ids = torch.randint(
|
prompt_ids = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
high=self.config.vocab_size,
|
high=self.config.vocab_size,
|
||||||
size=(batch_size, prompt_length),
|
size=(batch_size, prompt_length),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_ids = torch.randint(
|
gen_ids = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
high=self.config.vocab_size,
|
high=self.config.vocab_size,
|
||||||
size=(batch_size, total_length - prompt_length),
|
size=(batch_size, total_length - prompt_length),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt_ids, gen_ids
|
return prompt_ids, gen_ids
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_prefill_benchmark(
|
def run_prefill_benchmark(
|
||||||
self,
|
self,
|
||||||
|
|
@ -59,32 +65,38 @@ class GenerationBenchmark:
|
||||||
prompt_length: int = 512,
|
prompt_length: int = 512,
|
||||||
num_trials: int = 10,
|
num_trials: int = 10,
|
||||||
) -> BenchmarkResult:
|
) -> BenchmarkResult:
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
prompt_ids, _ = self._prepare_inputs(
|
||||||
|
batch_size, prompt_length, prompt_length
|
||||||
|
)
|
||||||
_ = self.model(prompt_ids)
|
_ = self.model(prompt_ids)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
total_time = 0.0
|
total_time = 0.0
|
||||||
total_tokens = batch_size * prompt_length * num_trials
|
total_tokens = batch_size * prompt_length * num_trials
|
||||||
|
|
||||||
for trial in range(num_trials):
|
for trial in range(num_trials):
|
||||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
prompt_ids, _ = self._prepare_inputs(
|
||||||
|
batch_size, prompt_length, prompt_length
|
||||||
|
)
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
start_event.record()
|
start_event.record()
|
||||||
_ = self.model(prompt_ids)
|
_ = self.model(prompt_ids)
|
||||||
end_event.record()
|
end_event.record()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||||
total_time += trial_time
|
total_time += trial_time
|
||||||
|
|
||||||
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
print(
|
||||||
f"({prompt_length / trial_time:.1f} tokens/s)")
|
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||||
|
f"({prompt_length / trial_time:.1f} tokens/s)"
|
||||||
|
)
|
||||||
|
|
||||||
return BenchmarkResult(
|
return BenchmarkResult(
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
total_time=total_time,
|
total_time=total_time,
|
||||||
|
|
@ -95,9 +107,9 @@ class GenerationBenchmark:
|
||||||
"prompt_length": prompt_length,
|
"prompt_length": prompt_length,
|
||||||
"dtype": self.dtype,
|
"dtype": self.dtype,
|
||||||
"device": self.device,
|
"device": self.device,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_decoding_benchmark(
|
def run_decoding_benchmark(
|
||||||
self,
|
self,
|
||||||
|
|
@ -106,39 +118,43 @@ class GenerationBenchmark:
|
||||||
gen_length: int = 128,
|
gen_length: int = 128,
|
||||||
num_trials: int = 5,
|
num_trials: int = 5,
|
||||||
) -> BenchmarkResult:
|
) -> BenchmarkResult:
|
||||||
|
|
||||||
total_time = 0.0
|
total_time = 0.0
|
||||||
total_tokens = batch_size * gen_length * num_trials
|
total_tokens = batch_size * gen_length * num_trials
|
||||||
|
|
||||||
for trial in range(num_trials):
|
for trial in range(num_trials):
|
||||||
|
prompt_ids, gen_ids = self._prepare_inputs(
|
||||||
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
batch_size, prompt_length, prompt_length + gen_length
|
||||||
|
)
|
||||||
kv_cache = self._initialize_kv_cache(batch_size)
|
kv_cache = self._initialize_kv_cache(batch_size)
|
||||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
start_event.record()
|
start_event.record()
|
||||||
|
|
||||||
current_pos = prompt_length
|
current_pos = prompt_length
|
||||||
for i in range(gen_length):
|
for i in range(gen_length):
|
||||||
input_token = gen_ids[:, i:i+1]
|
input_token = gen_ids[:, i : i + 1]
|
||||||
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
|
_ = self.model(
|
||||||
|
input_token, persistent_key_values=kv_cache, start_pos=current_pos
|
||||||
|
)
|
||||||
current_pos += 1
|
current_pos += 1
|
||||||
|
|
||||||
end_event.record()
|
end_event.record()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||||
total_time += trial_time
|
total_time += trial_time
|
||||||
|
|
||||||
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
|
||||||
f"({gen_length / trial_time:.1f} tokens/s)")
|
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||||
|
f"({gen_length / trial_time:.1f} tokens/s)"
|
||||||
|
)
|
||||||
|
|
||||||
return BenchmarkResult(
|
return BenchmarkResult(
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
total_time=total_time,
|
total_time=total_time,
|
||||||
|
|
@ -150,24 +166,28 @@ class GenerationBenchmark:
|
||||||
"gen_length": gen_length,
|
"gen_length": gen_length,
|
||||||
"dtype": self.dtype,
|
"dtype": self.dtype,
|
||||||
"device": self.device,
|
"device": self.device,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_benchmark_result(result: BenchmarkResult):
|
def print_benchmark_result(result: BenchmarkResult):
|
||||||
"""打印基准测试结果"""
|
"""打印基准测试结果"""
|
||||||
benchmark_type = result.metadata["benchmark_type"]
|
benchmark_type = result.metadata["benchmark_type"]
|
||||||
|
|
||||||
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
|
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
|
||||||
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
||||||
print(f"Time Consumed: {result.total_time:.3f}s")
|
print(f"Time Consumed: {result.total_time:.3f}s")
|
||||||
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
||||||
|
|
||||||
if benchmark_type == "prefill":
|
if benchmark_type == "prefill":
|
||||||
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
|
print(
|
||||||
|
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
|
||||||
|
)
|
||||||
elif benchmark_type == "decoding":
|
elif benchmark_type == "decoding":
|
||||||
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
|
print(
|
||||||
|
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
@ -183,16 +203,19 @@ if __name__ == "__main__":
|
||||||
n_layers=24,
|
n_layers=24,
|
||||||
norm_eps=1e-5,
|
norm_eps=1e-5,
|
||||||
)
|
)
|
||||||
|
|
||||||
benchmark = GenerationBenchmark(config)
|
benchmark = GenerationBenchmark(config)
|
||||||
|
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("Running Transformer Generation Benchmark")
|
print("Running Transformer Generation Benchmark")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5)
|
prefill_result = benchmark.run_prefill_benchmark(
|
||||||
|
batch_size=4, prompt_length=512, num_trials=5
|
||||||
|
)
|
||||||
print_benchmark_result(prefill_result)
|
print_benchmark_result(prefill_result)
|
||||||
|
|
||||||
gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5)
|
gen_result = benchmark.run_decoding_benchmark(
|
||||||
|
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
|
||||||
|
)
|
||||||
print_benchmark_result(gen_result)
|
print_benchmark_result(gen_result)
|
||||||
|
|
||||||
|
|
@ -21,10 +21,10 @@ def processor(
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
generator = BatchGenerator(param)
|
generator = BatchGenerator(param)
|
||||||
|
|
||||||
with open(input_json_file, "r", encoding='utf-8') as f:
|
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
queries = [item[question_key] for item in input_data]
|
queries = [item[question_key] for item in input_data]
|
||||||
|
|
@ -41,26 +41,62 @@ def processor(
|
||||||
|
|
||||||
responses = generator.generate(request)
|
responses = generator.generate(request)
|
||||||
|
|
||||||
with open(output_json_file, "w", encoding='utf-8') as f:
|
with open(output_json_file, "w", encoding="utf-8") as f:
|
||||||
for query, response in zip(queries, responses):
|
for query, response in zip(queries, responses):
|
||||||
output_item = {question_key: query, response_key: response}
|
output_item = {question_key: query, response_key: response}
|
||||||
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
|
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||||
|
|
||||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
parser.add_argument(
|
||||||
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
|
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||||
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
|
)
|
||||||
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
|
parser.add_argument(
|
||||||
parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.")
|
"--input_json_file",
|
||||||
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
|
type=str,
|
||||||
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
|
required=True,
|
||||||
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
|
help="Path to the input JSONL file.",
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_json_file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the output JSONL file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--question_key",
|
||||||
|
type=str,
|
||||||
|
default="question",
|
||||||
|
help="Key for the question in the input JSON.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--response_key",
|
||||||
|
type=str,
|
||||||
|
default="response",
|
||||||
|
help="Key for the response in the output JSON.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.60,
|
||||||
|
help="Temperature for generating responses.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_k", type=int, default=30, help="Top-k value for generating responses."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_p",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="Top-p value for generating responses.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Batch size for generating responses."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
processor(**vars(args))
|
processor(**vars(args))
|
||||||
|
|
|
||||||
|
|
@ -11,89 +11,99 @@ from khaosz.inference.core import disable_random_init
|
||||||
|
|
||||||
|
|
||||||
def compute_perplexity(
|
def compute_perplexity(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
input_mask: Tensor,
|
input_mask: Tensor,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Compute the perplexity of a batch of input sequences,
|
Compute the perplexity of a batch of input sequences,
|
||||||
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output = model(input_ids, input_mask)
|
output = model(input_ids, input_mask)
|
||||||
logits = output["logits"]
|
logits = output["logits"]
|
||||||
|
|
||||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||||
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
||||||
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
shifted_logits.flatten(0, 1),
|
shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
|
||||||
shifted_input_ids.flatten(0, 1),
|
|
||||||
reduction='none'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
||||||
loss = loss * shifted_mask
|
loss = loss * shifted_mask
|
||||||
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
|
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
|
||||||
perplexity = torch.exp(sentence_loss) # [batch_size]
|
perplexity = torch.exp(sentence_loss) # [batch_size]
|
||||||
|
|
||||||
return perplexity
|
return perplexity
|
||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
model_dir: str,
|
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||||
input_file: str,
|
|
||||||
output_file: str,
|
|
||||||
batch_size: int,
|
|
||||||
text_key: str
|
|
||||||
):
|
):
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
model = param.model
|
model = param.model
|
||||||
tokenizer = param.tokenizer
|
tokenizer = param.tokenizer
|
||||||
|
|
||||||
with open(input_file, "r", encoding='utf-8') as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
texts = [item[text_key] for item in input_data]
|
texts = [item[text_key] for item in input_data]
|
||||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||||
output_data = []
|
output_data = []
|
||||||
|
|
||||||
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
|
for i in tqdm(
|
||||||
batch_encoded = encoded_texts[i:i + batch_size]
|
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
|
||||||
batch_texts = texts[i:i + batch_size]
|
):
|
||||||
|
batch_encoded = encoded_texts[i : i + batch_size]
|
||||||
|
batch_texts = texts[i : i + batch_size]
|
||||||
max_len = max(len(seq) for seq in batch_encoded)
|
max_len = max(len(seq) for seq in batch_encoded)
|
||||||
padded_ids = []
|
padded_ids = []
|
||||||
masks = []
|
masks = []
|
||||||
|
|
||||||
for seq in batch_encoded:
|
for seq in batch_encoded:
|
||||||
pad_len = max_len - len(seq)
|
pad_len = max_len - len(seq)
|
||||||
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
||||||
mask = [False] * pad_len + [True] * len(seq)
|
mask = [False] * pad_len + [True] * len(seq)
|
||||||
padded_ids.append(padded_seq)
|
padded_ids.append(padded_seq)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
|
||||||
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
||||||
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
||||||
perplexity = compute_perplexity(model, input_ids, input_mask)
|
perplexity = compute_perplexity(model, input_ids, input_mask)
|
||||||
|
|
||||||
for text, ppl in zip(batch_texts, perplexity):
|
for text, ppl in zip(batch_texts, perplexity):
|
||||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||||
|
|
||||||
with open(output_file, "w", encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
for item in output_data:
|
for item in output_data:
|
||||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
parser.add_argument(
|
||||||
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
|
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||||
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
|
)
|
||||||
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
|
parser.add_argument(
|
||||||
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
|
"--input_file", type=str, required=True, help="Path to the input file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_file", type=str, required=True, help="Path to the output file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=4, help="Batch size for evaluation."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_key",
|
||||||
|
type=str,
|
||||||
|
default="text",
|
||||||
|
help="Key for the text field in the input data.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
|
|
||||||
205
tools/train.py
205
tools/train.py
|
|
@ -15,40 +15,130 @@ from khaosz.parallel import get_rank
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||||
|
|
||||||
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
|
parser.add_argument(
|
||||||
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
"--train_type",
|
||||||
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
type=str,
|
||||||
|
required=True,
|
||||||
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
|
choices=["seq", "sft", "dpo"],
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
|
help="Train type.",
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
|
)
|
||||||
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
|
parser.add_argument(
|
||||||
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
|
"--data_root_path",
|
||||||
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
|
type=str,
|
||||||
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
|
required=True,
|
||||||
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
|
help="Path to the root directory of the dataset.",
|
||||||
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
)
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
parser.add_argument(
|
||||||
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
|
"--param_path",
|
||||||
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
|
type=str,
|
||||||
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
required=True,
|
||||||
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
help="Path to the model parameters or resume checkpoint.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Batch size for training."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--accumulation_steps",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of iterations between each optimizer step.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--warmup_steps",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of iters between warnings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_grad_norm",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Max gradient norm for clipping.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adamw_beta1",
|
||||||
|
type=float,
|
||||||
|
default=0.9,
|
||||||
|
help="Beta values for AdamW optimizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adamw_beta2",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="Beta values for AdamW optimizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adamw_weight_decay",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="Weight decay for AdamW optimizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random_seed", type=int, default=3407, help="Random seed for reproducibility."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_workers", type=int, default=4, help="Number of workers for data loading."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_pin_memory",
|
||||||
|
action="store_false",
|
||||||
|
dest="pin_memory",
|
||||||
|
help="Disable pin memory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--window_size",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="the max length of the input sequence.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stride", type=int, default=None, help="the step size of the input sequence."
|
||||||
|
)
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter")
|
parser.add_argument(
|
||||||
|
"--label_smoothing",
|
||||||
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
type=float,
|
||||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
default=0.1,
|
||||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
help="cross_entropy function label smoothing parameter",
|
||||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt_interval",
|
||||||
|
type=int,
|
||||||
|
default=5000,
|
||||||
|
help="Number of iters between checkpoints.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt_dir",
|
||||||
|
type=str,
|
||||||
|
default="checkpoint",
|
||||||
|
help="Directory to save checkpoints.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--start_epoch", type=int, default=0, help="Start epoch for training."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
parser.add_argument(
|
||||||
|
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
local_rank = get_rank()
|
local_rank = get_rank()
|
||||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
||||||
|
|
@ -56,16 +146,21 @@ def ddp_wrap(model: nn.Module):
|
||||||
model,
|
model,
|
||||||
device_ids=[local_rank],
|
device_ids=[local_rank],
|
||||||
output_device=local_rank,
|
output_device=local_rank,
|
||||||
find_unused_parameters=False
|
find_unused_parameters=False,
|
||||||
)
|
)
|
||||||
return ddp_model
|
return ddp_model
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), **kwargs)
|
return optim.AdamW(model.parameters(), **kwargs)
|
||||||
|
|
||||||
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
|
|
||||||
|
def create_scheduler(
|
||||||
|
optimizer: optim.Optimizer, **kwargs
|
||||||
|
) -> optim.lr_scheduler.LRScheduler:
|
||||||
return SchedulerFactory.load(optimizer, **kwargs)
|
return SchedulerFactory.load(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||||
return model.module.state_dict()
|
return model.module.state_dict()
|
||||||
|
|
||||||
|
|
@ -81,8 +176,8 @@ def train(
|
||||||
start_batch: int,
|
start_batch: int,
|
||||||
accumulation_steps: int,
|
accumulation_steps: int,
|
||||||
warmup_steps: int,
|
warmup_steps: int,
|
||||||
checkpoint_interval: int,
|
ckpt_interval: int,
|
||||||
checkpoint_dir: str,
|
ckpt_dir: str,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
adamw_beta1: float,
|
adamw_beta1: float,
|
||||||
adamw_beta2: float,
|
adamw_beta2: float,
|
||||||
|
|
@ -99,48 +194,50 @@ def train(
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
parameter = ModelParameter.load(param_path)
|
parameter = ModelParameter.load(param_path)
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = parameter.config.max_len
|
window_size = parameter.config.max_len
|
||||||
|
|
||||||
model = parameter.model
|
model = parameter.model
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
||||||
"dpo_beta": dpo_beta,
|
|
||||||
"label_smoothing": label_smoothing
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
stride=stride
|
stride=stride,
|
||||||
)
|
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
|
||||||
warmup_steps=warmup_steps,
|
|
||||||
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer_fn = partial(create_optimizer,
|
schedule_config = CosineScheduleConfig(
|
||||||
**{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
|
warmup_steps=warmup_steps,
|
||||||
scheduler_fn = partial(create_scheduler,
|
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
||||||
**{"schedule_config": schedule_config})
|
)
|
||||||
|
|
||||||
|
optimizer_fn = partial(
|
||||||
|
create_optimizer,
|
||||||
|
**{
|
||||||
|
"lr": max_lr,
|
||||||
|
"betas": (adamw_beta1, adamw_beta2),
|
||||||
|
"weight_decay": adamw_weight_decay,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=model,
|
model=model,
|
||||||
strategy=train_type,
|
strategy=train_type,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=checkpoint_dir,
|
ckpt_dir=ckpt_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
start_batch=start_batch,
|
start_batch=start_batch,
|
||||||
checkpoint_interval=checkpoint_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
accumulation_steps=accumulation_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
|
|
@ -152,11 +249,11 @@ def train(
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
extra_kwargs=strategy_kwargs,
|
extra_kwargs=strategy_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
train(**vars(args))
|
train(**vars(args))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue