style: 使用ruff 工具优化代码风格
This commit is contained in:
parent
345fd2f091
commit
426af2d75f
|
|
@ -0,0 +1,27 @@
|
||||||
|
name: Lint
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python 3.12
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Check formatting with ruff
|
||||||
|
run: |
|
||||||
|
ruff format --check .
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
name: Spell Check
|
|
||||||
on: [push, pull_request]
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
spellcheck:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Check spelling in specific files
|
|
||||||
uses: codespell-project/actions-codespell@v2
|
|
||||||
with:
|
|
||||||
check_filenames: true
|
|
||||||
only_warn: false
|
|
||||||
path: "**/*.{md, py}"
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
name: Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.12"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
run: |
|
||||||
|
python -m pytest tests/ -v
|
||||||
|
|
@ -8,5 +8,8 @@
|
||||||
!*.py
|
!*.py
|
||||||
!*.md
|
!*.md
|
||||||
!*.png
|
!*.png
|
||||||
|
|
||||||
!LICENSE
|
!LICENSE
|
||||||
!pyproject.toml
|
!pyproject.toml
|
||||||
|
!.github/workflows/lint.yml
|
||||||
|
!.github/workflows/tests.yml
|
||||||
16
README.md
16
README.md
|
|
@ -54,8 +54,8 @@ python train.py \
|
||||||
--n_epoch=5 \
|
--n_epoch=5 \
|
||||||
--batch_size=8 \
|
--batch_size=8 \
|
||||||
--max_lr=2e-4 \
|
--max_lr=2e-4 \
|
||||||
--checkpoint_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
--checkpoint_dir=checkpoints
|
--ckpt_dir=checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
**Parameter Explanation:**
|
**Parameter Explanation:**
|
||||||
|
|
@ -67,8 +67,8 @@ python train.py \
|
||||||
- `--accumulation_steps`: Number of batches per training step
|
- `--accumulation_steps`: Number of batches per training step
|
||||||
- `--warmup_steps`: Warmup steps
|
- `--warmup_steps`: Warmup steps
|
||||||
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
|
||||||
- `--checkpoint_interval`: Checkpoint saving interval
|
- `--ckpt_interval`: Checkpoint saving interval
|
||||||
- `--checkpoint_dir`: Checkpoint saving directory
|
- `--ckpt_dir`: Checkpoint saving directory
|
||||||
- `--resume_dir`: Resume training from specified path
|
- `--resume_dir`: Resume training from specified path
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -191,8 +191,8 @@ python train.py \
|
||||||
--n_epoch=5 \
|
--n_epoch=5 \
|
||||||
--batch_size=8 \
|
--batch_size=8 \
|
||||||
--max_lr=2e-4 \
|
--max_lr=2e-4 \
|
||||||
--checkpoint_interval=10000 \
|
--ckpt_interval=10000 \
|
||||||
--checkpoint_dir=checkpoints
|
--ckpt_dir=checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
**参数说明:**
|
**参数说明:**
|
||||||
|
|
@ -204,8 +204,8 @@ python train.py \
|
||||||
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
- `--accumulation_steps`: 每个训练步骤的 batch 数量
|
||||||
- `--warmup_steps`: 预热步数(warmup steps)
|
- `--warmup_steps`: 预热步数(warmup steps)
|
||||||
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
|
||||||
- `--checkpoint_interval`: 检查点保存间隔
|
- `--ckpt_interval`: 检查点保存间隔
|
||||||
- `--checkpoint_dir`: 检查点保存目录
|
- `--ckpt_dir`: 检查点保存目录
|
||||||
- `--resume_dir`: 从指定路径恢复训练
|
- `--resume_dir`: 从指定路径恢复训练
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,12 @@ import os
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id="ViperEk/KHAOSZ",
|
repo_id="ViperEk/KHAOSZ",
|
||||||
local_dir=os.path.join(PROJECT_ROOT, "params"),
|
local_dir=os.path.join(PROJECT_ROOT, "params"),
|
||||||
force_download=True
|
force_download=True,
|
||||||
)
|
)
|
||||||
|
|
@ -5,8 +5,8 @@ from khaosz.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import LoopGenerator, GenerationRequest
|
from khaosz.inference.generator import LoopGenerator, GenerationRequest
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def generate_text():
|
def generate_text():
|
||||||
|
|
||||||
|
|
@ -14,7 +14,7 @@ def generate_text():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
|
|
||||||
request = GenerationRequest(
|
request = GenerationRequest(
|
||||||
|
|
@ -31,5 +31,6 @@ def generate_text():
|
||||||
|
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
generate_text()
|
generate_text()
|
||||||
|
|
@ -4,17 +4,23 @@ from khaosz.config.param_config import ModelParameter
|
||||||
from khaosz.inference.core import disable_random_init
|
from khaosz.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def batch_generate():
|
def batch_generate():
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
generator = BatchGenerator(param)
|
generator = BatchGenerator(param)
|
||||||
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
|
inputs = [
|
||||||
|
"你好",
|
||||||
|
"请问什么是人工智能",
|
||||||
|
"今天天气如何",
|
||||||
|
"我感到焦虑, 请问我应该怎么办",
|
||||||
|
"请问什么是显卡",
|
||||||
|
]
|
||||||
|
|
||||||
request = GenerationRequest(
|
request = GenerationRequest(
|
||||||
query=inputs,
|
query=inputs,
|
||||||
|
|
@ -30,5 +36,6 @@ def batch_generate():
|
||||||
for q, r in zip(inputs, responses):
|
for q, r in zip(inputs, responses):
|
||||||
print((q, r))
|
print((q, r))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
batch_generate()
|
batch_generate()
|
||||||
|
|
@ -5,8 +5,8 @@ from khaosz.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
def chat():
|
def chat():
|
||||||
|
|
||||||
|
|
@ -14,7 +14,7 @@ def chat():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
generator = StreamGenerator(param)
|
generator = StreamGenerator(param)
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
|
|
|
||||||
|
|
@ -6,41 +6,30 @@ from khaosz.config import (
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
from khaosz.data import (
|
from khaosz.data import DatasetLoader, BpeTokenizer
|
||||||
DatasetLoader,
|
|
||||||
BpeTokenizer
|
|
||||||
)
|
|
||||||
from khaosz.inference.generator import (
|
from khaosz.inference.generator import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
LoopGenerator,
|
LoopGenerator,
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
GeneratorFactory
|
GeneratorFactory,
|
||||||
)
|
|
||||||
from khaosz.trainer import (
|
|
||||||
Trainer,
|
|
||||||
StrategyFactory,
|
|
||||||
SchedulerFactory
|
|
||||||
)
|
)
|
||||||
|
from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Transformer",
|
"Transformer",
|
||||||
|
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
|
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"LoopGenerator",
|
"LoopGenerator",
|
||||||
"StreamGenerator",
|
"StreamGenerator",
|
||||||
"BatchGenerator",
|
"BatchGenerator",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
"GeneratorFactory",
|
"GeneratorFactory",
|
||||||
|
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory"
|
"SchedulerFactory",
|
||||||
]
|
]
|
||||||
|
|
@ -4,7 +4,7 @@ from khaosz.config.schedule_config import (
|
||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
CosineScheduleConfig,
|
CosineScheduleConfig,
|
||||||
SGDRScheduleConfig,
|
SGDRScheduleConfig,
|
||||||
ScheduleConfigFactory
|
ScheduleConfigFactory,
|
||||||
)
|
)
|
||||||
from khaosz.config.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
@ -13,11 +13,9 @@ __all__ = [
|
||||||
# Base I/O
|
# Base I/O
|
||||||
"BaseModelIO",
|
"BaseModelIO",
|
||||||
"ModelParameter",
|
"ModelParameter",
|
||||||
|
|
||||||
# Model configuration
|
# Model configuration
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
# Schedule configuration
|
# Schedule configuration
|
||||||
"ScheduleConfig",
|
"ScheduleConfig",
|
||||||
"CosineScheduleConfig",
|
"CosineScheduleConfig",
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,9 @@ class ModelConfig:
|
||||||
use_qk_norm: Optional[bool] = None
|
use_qk_norm: Optional[bool] = None
|
||||||
use_gated_attention: Optional[bool] = None
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
def load(self, config_path: str) -> Self:
|
||||||
config = {}
|
config = {}
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, "r") as f:
|
||||||
config.update(json.load(f))
|
config.update(json.load(f))
|
||||||
|
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
|
|
@ -39,5 +38,5 @@ class ModelConfig:
|
||||||
|
|
||||||
def save(self, config_path: str):
|
def save(self, config_path: str):
|
||||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_dict, f, indent=4)
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
|
||||||
|
|
@ -9,21 +9,20 @@ from khaosz.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.config.model_config import ModelConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseModelIO:
|
class BaseModelIO:
|
||||||
"""Base class for model I/O operations."""
|
"""Base class for model I/O operations."""
|
||||||
|
|
||||||
model: Optional[nn.Module] = field(
|
model: Optional[nn.Module] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Transformer model."}
|
||||||
metadata={"help": "Transformer model."}
|
|
||||||
)
|
)
|
||||||
tokenizer: BpeTokenizer = field(
|
tokenizer: BpeTokenizer = field(
|
||||||
default_factory=BpeTokenizer,
|
default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
|
||||||
metadata={"help": "Tokenizer for the model."}
|
|
||||||
)
|
)
|
||||||
config: ModelConfig = field(
|
config: ModelConfig = field(
|
||||||
default_factory=ModelConfig,
|
default_factory=ModelConfig,
|
||||||
metadata={"help": "Transformer model configuration."}
|
metadata={"help": "Transformer model configuration."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||||
|
|
@ -32,7 +31,7 @@ class BaseModelIO:
|
||||||
return {
|
return {
|
||||||
"model": dir_path / "model.safetensors",
|
"model": dir_path / "model.safetensors",
|
||||||
"config": dir_path / "config.json",
|
"config": dir_path / "config.json",
|
||||||
"tokenizer": dir_path / "tokenizer.json"
|
"tokenizer": dir_path / "tokenizer.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_components(self, save_dir: Union[str, Path]):
|
def save_components(self, save_dir: Union[str, Path]):
|
||||||
|
|
@ -80,4 +79,3 @@ class ModelParameter(BaseModelIO):
|
||||||
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
|
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||||
instance = cls()
|
instance = cls()
|
||||||
return instance.load_components(load_dir)
|
return instance.load_components(load_dir)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,14 @@ class ScheduleConfig(ABC):
|
||||||
default="cosine",
|
default="cosine",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Type of learning rate schedule.",
|
"help": "Type of learning rate schedule.",
|
||||||
"choices": ["cosine", "sgdr"]
|
"choices": ["cosine", "sgdr"],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
warmup_steps: int = field(
|
warmup_steps: int = field(
|
||||||
default=1000,
|
default=1000, metadata={"help": "Number of warmup steps."}
|
||||||
metadata={"help": "Number of warmup steps."}
|
|
||||||
)
|
)
|
||||||
min_rate: float = field(
|
min_rate: float = field(
|
||||||
default=0.05,
|
default=0.05, metadata={"help": "Minimum learning rate multiplier."}
|
||||||
metadata={"help": "Minimum learning rate multiplier."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -34,7 +32,9 @@ class ScheduleConfig(ABC):
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
"""Validate configuration parameters."""
|
"""Validate configuration parameters."""
|
||||||
if self.warmup_steps < 0:
|
if self.warmup_steps < 0:
|
||||||
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
|
raise ValueError(
|
||||||
|
f"warmup_steps must be non-negative, got {self.warmup_steps}"
|
||||||
|
)
|
||||||
if not 0 <= self.min_rate <= 1:
|
if not 0 <= self.min_rate <= 1:
|
||||||
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
||||||
|
|
||||||
|
|
@ -44,8 +44,7 @@ class CosineScheduleConfig(ScheduleConfig):
|
||||||
"""Cosine annealing learning rate schedule configuration."""
|
"""Cosine annealing learning rate schedule configuration."""
|
||||||
|
|
||||||
total_steps: int = field(
|
total_steps: int = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Total training steps for cosine schedule."}
|
||||||
metadata={"help": "Total training steps for cosine schedule."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|
@ -60,13 +59,15 @@ class CosineScheduleConfig(ScheduleConfig):
|
||||||
"schedule_type": self.schedule_type,
|
"schedule_type": self.schedule_type,
|
||||||
"warmup_steps": self.warmup_steps,
|
"warmup_steps": self.warmup_steps,
|
||||||
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
||||||
"min_rate": self.min_rate
|
"min_rate": self.min_rate,
|
||||||
}
|
}
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
super().validate()
|
super().validate()
|
||||||
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
||||||
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
raise ValueError(
|
||||||
|
f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -74,12 +75,10 @@ class SGDRScheduleConfig(ScheduleConfig):
|
||||||
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
|
"""Stochastic Gradient Descent with Warm Restarts schedule configuration."""
|
||||||
|
|
||||||
cycle_length: int = field(
|
cycle_length: int = field(
|
||||||
default=1000,
|
default=1000, metadata={"help": "Length of the first cycle in steps."}
|
||||||
metadata={"help": "Length of the first cycle in steps."}
|
|
||||||
)
|
)
|
||||||
t_mult: int = field(
|
t_mult: int = field(
|
||||||
default=2,
|
default=2, metadata={"help": "Multiplier for cycle length growth."}
|
||||||
metadata={"help": "Multiplier for cycle length growth."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|
@ -92,7 +91,7 @@ class SGDRScheduleConfig(ScheduleConfig):
|
||||||
"warmup_steps": self.warmup_steps,
|
"warmup_steps": self.warmup_steps,
|
||||||
"cycle_length": self.cycle_length,
|
"cycle_length": self.cycle_length,
|
||||||
"min_rate": self.min_rate,
|
"min_rate": self.min_rate,
|
||||||
"t_mult": self.t_mult
|
"t_mult": self.t_mult,
|
||||||
}
|
}
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -10,127 +10,92 @@ from typing import Callable, List, Optional
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
# basic setting
|
# basic setting
|
||||||
model: nn.Module = field(
|
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||||
default=None,
|
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||||
metadata={"help": "Model for training."}
|
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
|
||||||
)
|
|
||||||
strategy: str = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Training strategy."}
|
|
||||||
)
|
|
||||||
dataset: Dataset = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Dataset for training."}
|
|
||||||
)
|
|
||||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Optimizer factory for training."}
|
||||||
metadata={"help": "Optimizer factory for training."}
|
|
||||||
)
|
)
|
||||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Scheduler factory for training."}
|
||||||
metadata={"help": "Scheduler factory for training."}
|
|
||||||
)
|
|
||||||
n_epoch: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "Number of epochs for training."}
|
|
||||||
)
|
|
||||||
batch_size: int = field(
|
|
||||||
default=4,
|
|
||||||
metadata={"help": "Batch size for training."}
|
|
||||||
)
|
)
|
||||||
|
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||||
|
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
|
||||||
accumulation_steps: int = field(
|
accumulation_steps: int = field(
|
||||||
default=1,
|
default=1, metadata={"help": "Number of iterations between steps."}
|
||||||
metadata={"help": "Number of iterations between steps."}
|
|
||||||
)
|
)
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
default=1.0,
|
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||||
metadata={"help": "Maximum gradient norm."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# checkpoint setting
|
# checkpoint setting
|
||||||
start_epoch: int = field(
|
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||||
default=0,
|
|
||||||
metadata={"help": "Start epoch for training."}
|
|
||||||
)
|
|
||||||
start_batch: int = field(
|
start_batch: int = field(
|
||||||
default=0,
|
default=0, metadata={"help": "Start batch iteration for training."}
|
||||||
metadata={"help": "Start batch iteration for training."}
|
|
||||||
)
|
)
|
||||||
checkpoint_dir: str = field(
|
ckpt_dir: str = field(
|
||||||
default="./checkpoint",
|
default="./checkpoint", metadata={"help": "Checkpoint directory."}
|
||||||
metadata={"help": "Checkpoint directory."}
|
|
||||||
)
|
)
|
||||||
checkpoint_interval: int = field(
|
ckpt_interval: int = field(
|
||||||
default=5000,
|
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||||
metadata={"help": "Number of iterations between checkpoints."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# dataloader setting
|
# dataloader setting
|
||||||
random_seed: int = field(
|
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||||
default=3407,
|
|
||||||
metadata={"help": "Random seed."}
|
|
||||||
)
|
|
||||||
num_workers: int = field(
|
num_workers: int = field(
|
||||||
default=0,
|
default=0, metadata={"help": "Number of workers for dataloader."}
|
||||||
metadata={"help": "Number of workers for dataloader."}
|
|
||||||
)
|
)
|
||||||
prefetch_factor: Optional[int] = field(
|
prefetch_factor: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Prefetch factor for dataloader."}
|
||||||
metadata={"help": "Prefetch factor for dataloader."}
|
|
||||||
)
|
)
|
||||||
pin_memory: bool = field(
|
pin_memory: bool = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Pin memory for dataloader."}
|
||||||
metadata={"help": "Pin memory for dataloader."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# distributed training
|
# distributed training
|
||||||
nprocs: int = field(
|
nprocs: int = field(
|
||||||
default=1,
|
default=1, metadata={"help": "Number of processes for distributed training."}
|
||||||
metadata={"help": "Number of processes for distributed training."}
|
|
||||||
)
|
)
|
||||||
backend: str = field(
|
backend: str = field(
|
||||||
default="nccl",
|
default="nccl", metadata={"help": "Distributed training backend."}
|
||||||
metadata={"help": "Distributed training backend."}
|
|
||||||
)
|
)
|
||||||
master_addr: str = field(
|
master_addr: str = field(
|
||||||
default="localhost",
|
default="localhost",
|
||||||
metadata={"help": "Master address for distributed training."}
|
metadata={"help": "Master address for distributed training."},
|
||||||
)
|
)
|
||||||
master_port: str = field(
|
master_port: str = field(
|
||||||
default="29500",
|
default="29500", metadata={"help": "Master port for distributed training."}
|
||||||
metadata={"help": "Master port for distributed training."}
|
|
||||||
)
|
)
|
||||||
parallel_wrapper: Optional[Callable] = field(
|
parallel_wrapper: Optional[Callable] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Parallel function for training."}
|
||||||
metadata={"help": "Parallel function for training."}
|
|
||||||
)
|
)
|
||||||
state_dict_fn: Optional[Callable] = field(
|
state_dict_fn: Optional[Callable] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||||
metadata={"help": "Parallel function for state dict saving."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
device_ids: Optional[List[int]] = field(
|
device_ids: Optional[List[int]] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "Device ids for distributed training."}
|
||||||
metadata={"help": "Device ids for distributed training."}
|
|
||||||
)
|
)
|
||||||
device_type: str = field(
|
device_type: str = field(
|
||||||
default="cuda",
|
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||||
metadata={"help": "Device type for distributed training."}
|
|
||||||
)
|
)
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
default_factory=dict,
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
metadata={"help": "Other arguments."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"]
|
required_fields = [
|
||||||
|
"model",
|
||||||
|
"strategy",
|
||||||
|
"dataset",
|
||||||
|
"optimizer_fn",
|
||||||
|
"scheduler_fn",
|
||||||
|
]
|
||||||
|
|
||||||
for field_name in required_fields:
|
for field_name in required_fields:
|
||||||
if getattr(self, field_name) is None:
|
if getattr(self, field_name) is None:
|
||||||
raise ValueError(f"{field_name} is required.")
|
raise ValueError(f"{field_name} is required.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -6,7 +6,7 @@ from khaosz.data.dataset import (
|
||||||
GRPODataset,
|
GRPODataset,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
DatasetLoader,
|
DatasetLoader,
|
||||||
DatasetFactory
|
DatasetFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
|
|
@ -15,21 +15,17 @@ from khaosz.data.sampler import ResumableDistributedSampler
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base classes
|
# Base classes
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
|
|
||||||
# Dataset implementations
|
# Dataset implementations
|
||||||
"SEQDataset",
|
"SEQDataset",
|
||||||
"SFTDataset",
|
"SFTDataset",
|
||||||
"DPODataset",
|
"DPODataset",
|
||||||
"GRPODataset",
|
"GRPODataset",
|
||||||
|
|
||||||
# Fetchers
|
# Fetchers
|
||||||
"MultiSegmentFetcher",
|
"MultiSegmentFetcher",
|
||||||
|
|
||||||
# Factory (DatasetLoader is alias for backward compatibility)
|
# Factory (DatasetLoader is alias for backward compatibility)
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
|
|
||||||
# Tokenizer and sampler
|
# Tokenizer and sampler
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
"ResumableDistributedSampler"
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
@ -41,7 +41,9 @@ class BaseSegmentFetcher:
|
||||||
Returns:
|
Returns:
|
||||||
Concatenated tensor of data in the specified range
|
Concatenated tensor of data in the specified range
|
||||||
"""
|
"""
|
||||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
if not (
|
||||||
|
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||||
|
):
|
||||||
raise ValueError("begin_idx or end_idx out of bounds")
|
raise ValueError("begin_idx or end_idx out of bounds")
|
||||||
if begin_idx >= end_idx:
|
if begin_idx >= end_idx:
|
||||||
return torch.tensor([], dtype=torch.long)
|
return torch.tensor([], dtype=torch.long)
|
||||||
|
|
@ -71,8 +73,7 @@ class MultiSegmentFetcher:
|
||||||
def __init__(self, muti_segments: Dict):
|
def __init__(self, muti_segments: Dict):
|
||||||
self.muti_keys = list(muti_segments.keys())
|
self.muti_keys = list(muti_segments.keys())
|
||||||
self.muti_fetchers = {
|
self.muti_fetchers = {
|
||||||
key: BaseSegmentFetcher(segments)
|
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
|
||||||
for key, segments in muti_segments.items()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -80,7 +81,9 @@ class MultiSegmentFetcher:
|
||||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||||
return min(len_list)
|
return min(len_list)
|
||||||
|
|
||||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
def key_fetch(
|
||||||
|
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||||
|
) -> Dict:
|
||||||
"""Fetch data for specific keys.
|
"""Fetch data for specific keys.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -189,11 +192,13 @@ class DatasetFactory:
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function that registers the dataset class
|
Decorator function that registers the dataset class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(dataset_cls: type) -> type:
|
def decorator(dataset_cls: type) -> type:
|
||||||
if not issubclass(dataset_cls, BaseDataset):
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
cls.DATASET_MAP[name] = dataset_cls
|
cls.DATASET_MAP[name] = dataset_cls
|
||||||
return dataset_cls
|
return dataset_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -223,7 +228,13 @@ class DatasetFactory:
|
||||||
return dataset_cls(window_size, stride)
|
return dataset_cls(window_size, stride)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset:
|
def load(
|
||||||
|
cls,
|
||||||
|
train_type: str,
|
||||||
|
load_path: str,
|
||||||
|
window_size: int,
|
||||||
|
stride: Optional[int] = None,
|
||||||
|
) -> BaseDataset:
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -286,8 +297,12 @@ class SFTDataset(BaseDataset):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||||
|
|
||||||
|
|
@ -307,10 +322,19 @@ class DPODataset(BaseDataset):
|
||||||
|
|
||||||
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||||
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||||
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
|
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
|
||||||
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
return {
|
||||||
|
"chosen": chosen,
|
||||||
|
"rejected": rejected,
|
||||||
|
"chosen_mask": chosen_mask,
|
||||||
|
"rejected_mask": rejected_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("grpo")
|
@DatasetFactory.register("grpo")
|
||||||
|
|
@ -331,7 +355,12 @@ class GRPODataset(BaseDataset):
|
||||||
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
return {
|
||||||
|
"prompts": prompts,
|
||||||
|
"responses": responses,
|
||||||
|
"masks": masks,
|
||||||
|
"rewards": rewards,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias
|
# Backward compatibility alias
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,12 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_source: Dataset,
|
data_source: Dataset,
|
||||||
start_epoch: int=0,
|
start_epoch: int = 0,
|
||||||
start_iter: int=0,
|
start_iter: int = 0,
|
||||||
seed: int=42,
|
seed: int = 42,
|
||||||
drop_last: bool=False,
|
drop_last: bool = False,
|
||||||
shuffle: bool=True,
|
shuffle: bool = True,
|
||||||
process_group: Optional[dist.ProcessGroup]=None,
|
process_group: Optional[dist.ProcessGroup] = None,
|
||||||
):
|
):
|
||||||
self.epoch = start_epoch
|
self.epoch = start_epoch
|
||||||
self.iter = start_iter
|
self.iter = start_iter
|
||||||
|
|
@ -40,7 +40,7 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
|
|
||||||
|
|
@ -58,10 +58,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
padding_size = self.total_size - len(indices)
|
padding_size = self.total_size - len(indices)
|
||||||
indices += indices[:padding_size]
|
indices += indices[:padding_size]
|
||||||
|
|
||||||
local_indices = indices[self.rank:self.total_size:self.num_replicas]
|
local_indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||||
|
|
||||||
self.iter = self.iter % self.num_samples_per_replica
|
self.iter = self.iter % self.num_samples_per_replica
|
||||||
self._indices = local_indices[self.iter:]
|
self._indices = local_indices[self.iter :]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if self._indices is None:
|
if self._indices is None:
|
||||||
|
|
|
||||||
|
|
@ -10,15 +10,17 @@ from torch import Tensor
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from khaosz.parallel.setup import get_rank
|
from khaosz.parallel.setup import get_rank
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||||
with h5py.File(full_file_path, 'w') as f:
|
with h5py.File(full_file_path, "w") as f:
|
||||||
for key, tensors in tensor_group.items():
|
for key, tensors in tensor_group.items():
|
||||||
grp = f.create_group(key)
|
grp = f.create_group(key)
|
||||||
for idx, tensor in enumerate(tensors):
|
for idx, tensor in enumerate(tensors):
|
||||||
arr = tensor.cpu().numpy()
|
arr = tensor.cpu().numpy()
|
||||||
grp.create_dataset(f'data_{idx}', data=arr)
|
grp.create_dataset(f"data_{idx}", data=arr)
|
||||||
|
|
||||||
|
|
||||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
|
|
@ -27,7 +29,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||||
|
|
||||||
for h5_file in h5_files:
|
for h5_file in h5_files:
|
||||||
with h5py.File(h5_file, 'r') as f:
|
with h5py.File(h5_file, "r") as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
grp = f[key]
|
grp = f[key]
|
||||||
dsets = []
|
dsets = []
|
||||||
|
|
|
||||||
|
|
@ -12,15 +12,16 @@ class BpeTokenizer:
|
||||||
|
|
||||||
model = BPE()
|
model = BPE()
|
||||||
self._tokenizer = Tokenizer(model)
|
self._tokenizer = Tokenizer(model)
|
||||||
self._tokenizer.normalizer = normalizers.Sequence([
|
self._tokenizer.normalizer = normalizers.Sequence(
|
||||||
normalizers.NFC(),
|
[normalizers.NFC(), normalizers.Strip()]
|
||||||
normalizers.Strip()
|
)
|
||||||
])
|
|
||||||
|
|
||||||
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
|
self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||||
pre_tokenizers.UnicodeScripts(),
|
[
|
||||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)
|
pre_tokenizers.UnicodeScripts(),
|
||||||
])
|
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
self._tokenizer.decoder = decoders.ByteLevel()
|
self._tokenizer.decoder = decoders.ByteLevel()
|
||||||
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
|
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
|
||||||
|
|
@ -28,10 +29,21 @@ class BpeTokenizer:
|
||||||
if path is not None:
|
if path is not None:
|
||||||
self._tokenizer = Tokenizer.from_file(path)
|
self._tokenizer = Tokenizer.from_file(path)
|
||||||
|
|
||||||
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple:
|
def _prepare_trainer(
|
||||||
|
self,
|
||||||
|
vocab_size: int,
|
||||||
|
min_freq: int,
|
||||||
|
reserved_token_size: int,
|
||||||
|
max_token_length=18,
|
||||||
|
) -> tuple:
|
||||||
assert reserved_token_size > len(self._special_tokens)
|
assert reserved_token_size > len(self._special_tokens)
|
||||||
reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
|
reserved_tokens = [
|
||||||
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
|
f"<|reserve{i:02d}|>"
|
||||||
|
for i in range(reserved_token_size - len(self._special_tokens))
|
||||||
|
]
|
||||||
|
detail_vocab_size = vocab_size - (
|
||||||
|
len(reserved_tokens) + len(self._special_tokens)
|
||||||
|
)
|
||||||
|
|
||||||
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
||||||
min_size = len(alphabet) + len(self._control_tokens)
|
min_size = len(alphabet) + len(self._control_tokens)
|
||||||
|
|
@ -53,16 +65,18 @@ class BpeTokenizer:
|
||||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
min_freq=min_freq,
|
min_freq=min_freq,
|
||||||
reserved_token_size=reserved_token_size
|
reserved_token_size=reserved_token_size,
|
||||||
)
|
)
|
||||||
self._tokenizer.train(files=files, trainer=trainer)
|
self._tokenizer.train(files=files, trainer=trainer)
|
||||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||||
|
|
||||||
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
|
def train_from_iterator(
|
||||||
|
self, iterator, vocab_size, min_freq, reserved_token_size=100
|
||||||
|
):
|
||||||
trainer, _, reserved_tokens = self._prepare_trainer(
|
trainer, _, reserved_tokens = self._prepare_trainer(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
min_freq=min_freq,
|
min_freq=min_freq,
|
||||||
reserved_token_size=reserved_token_size
|
reserved_token_size=reserved_token_size,
|
||||||
)
|
)
|
||||||
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
|
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
|
||||||
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
|
||||||
|
|
@ -73,15 +87,26 @@ class BpeTokenizer:
|
||||||
def load(self, path):
|
def load(self, path):
|
||||||
self._tokenizer = Tokenizer.from_file(path)
|
self._tokenizer = Tokenizer.from_file(path)
|
||||||
|
|
||||||
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
|
def encode(
|
||||||
|
self,
|
||||||
|
tokens: Union[str, List[str]],
|
||||||
|
out_ids: bool = True,
|
||||||
|
add_special_tokens: bool = False,
|
||||||
|
) -> List:
|
||||||
if isinstance(tokens, str):
|
if isinstance(tokens, str):
|
||||||
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
|
encoded: Encoding = self._tokenizer.encode(
|
||||||
|
tokens, add_special_tokens=add_special_tokens
|
||||||
|
)
|
||||||
return encoded.ids if out_ids else encoded.tokens
|
return encoded.ids if out_ids else encoded.tokens
|
||||||
elif isinstance(tokens, list):
|
elif isinstance(tokens, list):
|
||||||
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
|
encoded_list: List[Encoding] = self._tokenizer.encode_batch(
|
||||||
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
|
tokens, add_special_tokens=add_special_tokens
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
|
||||||
|
]
|
||||||
|
|
||||||
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
|
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
||||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from khaosz.inference.generator import (
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
GeneratorFactory
|
GeneratorFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -19,11 +19,10 @@ __all__ = [
|
||||||
"GeneratorCore",
|
"GeneratorCore",
|
||||||
"EmbeddingEncoderCore",
|
"EmbeddingEncoderCore",
|
||||||
"KVCacheManager",
|
"KVCacheManager",
|
||||||
|
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"LoopGenerator",
|
"LoopGenerator",
|
||||||
"StreamGenerator",
|
"StreamGenerator",
|
||||||
"BatchGenerator",
|
"BatchGenerator",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
"GeneratorFactory"
|
"GeneratorFactory",
|
||||||
]
|
]
|
||||||
|
|
@ -12,7 +12,7 @@ def apply_sampling_strategies(
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
filter_value: float = -float("inf")
|
filter_value: float = -float("inf"),
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Apply sampling strategies to the logits tensor.
|
Apply sampling strategies to the logits tensor.
|
||||||
|
|
@ -47,9 +47,7 @@ def apply_sampling_strategies(
|
||||||
|
|
||||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||||
indices_to_remove.scatter_(
|
indices_to_remove.scatter_(
|
||||||
dim=1,
|
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||||
index=sorted_indices,
|
|
||||||
src=sorted_indices_to_remove
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logits[indices_to_remove] = filter_value
|
logits[indices_to_remove] = filter_value
|
||||||
|
|
@ -60,10 +58,15 @@ def apply_sampling_strategies(
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def disable_random_init():
|
def disable_random_init():
|
||||||
init_functions = [
|
init_functions = [
|
||||||
'xavier_normal_', 'xavier_uniform_',
|
"xavier_normal_",
|
||||||
'kaiming_normal_', 'kaiming_uniform_',
|
"xavier_uniform_",
|
||||||
'zeros_', 'ones_', 'constant_',
|
"kaiming_normal_",
|
||||||
'normal_', 'uniform_'
|
"kaiming_uniform_",
|
||||||
|
"zeros_",
|
||||||
|
"ones_",
|
||||||
|
"constant_",
|
||||||
|
"normal_",
|
||||||
|
"uniform_",
|
||||||
]
|
]
|
||||||
original_funcs = {}
|
original_funcs = {}
|
||||||
for name in init_functions:
|
for name in init_functions:
|
||||||
|
|
@ -91,8 +94,8 @@ class GeneratorCore:
|
||||||
top_p: float,
|
top_p: float,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
)-> Tuple[Tensor, int]:
|
) -> Tuple[Tensor, int]:
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
||||||
|
|
@ -115,13 +118,20 @@ class GeneratorCore:
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
callback: Optional[Callable[..., Any]] = None
|
callback: Optional[Callable[..., Any]] = None,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
cur_cache_pos = start_pos
|
cur_cache_pos = start_pos
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.max_len):
|
for _ in range(len(ids), self.config.max_len):
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
input_ids,
|
||||||
|
temperature,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
attn_mask,
|
||||||
|
kv_caches,
|
||||||
|
cur_cache_pos,
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = next_token_id
|
input_ids = next_token_id
|
||||||
ids.append(next_token_id.item())
|
ids.append(next_token_id.item())
|
||||||
|
|
@ -157,14 +167,17 @@ class EmbeddingEncoderCore:
|
||||||
|
|
||||||
for i, seq in enumerate(batch_ids):
|
for i, seq in enumerate(batch_ids):
|
||||||
if len(seq) > max_model_len:
|
if len(seq) > max_model_len:
|
||||||
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
|
fragments = [
|
||||||
|
seq[j : j + max_model_len]
|
||||||
|
for j in range(0, len(seq), max_model_len)
|
||||||
|
]
|
||||||
all_fragments.extend(fragments)
|
all_fragments.extend(fragments)
|
||||||
fragment_origin_idx.extend([i] * len(fragments))
|
fragment_origin_idx.extend([i] * len(fragments))
|
||||||
else:
|
else:
|
||||||
all_fragments.append(seq)
|
all_fragments.append(seq)
|
||||||
fragment_origin_idx.append(i)
|
fragment_origin_idx.append(i)
|
||||||
|
|
||||||
#if empty fragments
|
# if empty fragments
|
||||||
if not all_fragments or not ids:
|
if not all_fragments or not ids:
|
||||||
return [] if with_batch else torch.tensor([])
|
return [] if with_batch else torch.tensor([])
|
||||||
|
|
||||||
|
|
@ -190,11 +203,17 @@ class EmbeddingEncoderCore:
|
||||||
|
|
||||||
sentence_embs: List[Tensor] = []
|
sentence_embs: List[Tensor] = []
|
||||||
for i in range(len(batch_ids)):
|
for i in range(len(batch_ids)):
|
||||||
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
indices = [
|
||||||
|
idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i
|
||||||
|
]
|
||||||
if indices:
|
if indices:
|
||||||
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
sum_frags = torch.sum(
|
||||||
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
fragment_embs[indices, :, :], dim=1
|
||||||
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
) # [frags, hidden_size]
|
||||||
|
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(
|
||||||
|
1
|
||||||
|
) # [frags, 1]
|
||||||
|
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
||||||
sentence_embs.append(emb.flatten())
|
sentence_embs.append(emb.flatten())
|
||||||
|
|
||||||
if with_batch:
|
if with_batch:
|
||||||
|
|
@ -213,7 +232,7 @@ class KVCacheManager:
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
device: torch.device = "cuda",
|
device: torch.device = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
):
|
):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
@ -221,7 +240,7 @@ class KVCacheManager:
|
||||||
self.num_layers = config.n_layers
|
self.num_layers = config.n_layers
|
||||||
self.max_len = config.max_len
|
self.max_len = config.max_len
|
||||||
self.num_heads = config.n_kv_heads
|
self.num_heads = config.n_kv_heads
|
||||||
self.head_dim = config.dim //config.n_heads
|
self.head_dim = config.dim // config.n_heads
|
||||||
|
|
||||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||||
self._seq_mask: Tensor = None
|
self._seq_mask: Tensor = None
|
||||||
|
|
@ -229,15 +248,31 @@ class KVCacheManager:
|
||||||
|
|
||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
k_cache = torch.empty(
|
k_cache = torch.empty(
|
||||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
(
|
||||||
device=self.device, dtype=self.dtype
|
self.batch_size,
|
||||||
|
self.max_len,
|
||||||
|
self.num_layers,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
),
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
v_cache = torch.empty(
|
v_cache = torch.empty(
|
||||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
(
|
||||||
device=self.device, dtype=self.dtype
|
self.batch_size,
|
||||||
|
self.max_len,
|
||||||
|
self.num_layers,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
),
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self._kv_cache = (k_cache, v_cache)
|
self._kv_cache = (k_cache, v_cache)
|
||||||
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)
|
self._seq_mask = torch.ones(
|
||||||
|
(self.batch_size, self.max_len), device=self.device, dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
def update(self, active_mask: Tensor):
|
def update(self, active_mask: Tensor):
|
||||||
k_cache, v_cache = self._kv_cache
|
k_cache, v_cache = self._kv_cache
|
||||||
|
|
@ -253,8 +288,8 @@ class KVCacheManager:
|
||||||
|
|
||||||
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
|
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
|
||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = input_ids.shape
|
||||||
bool_mask = (input_ids != pad_id)
|
bool_mask = input_ids != pad_id
|
||||||
self._seq_mask[: batch_size, : seq_len] = bool_mask
|
self._seq_mask[:batch_size, :seq_len] = bool_mask
|
||||||
|
|
||||||
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
|
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
|
||||||
return self._kv_cache
|
return self._kv_cache
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,11 @@ from khaosz.config.param_config import ModelParameter
|
||||||
|
|
||||||
HistoryType = List[Tuple[str, str]]
|
HistoryType = List[Tuple[str, str]]
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
query: str,
|
query: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
history: Optional[HistoryType] = None
|
history: Optional[HistoryType] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build prompt in ChatML format for query and history.
|
Build prompt in ChatML format for query and history.
|
||||||
|
|
@ -79,6 +80,7 @@ class GenerationRequest:
|
||||||
system_prompt: System prompt for the conversation.
|
system_prompt: System prompt for the conversation.
|
||||||
stream: Whether to use streaming generation.
|
stream: Whether to use streaming generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
top_k: int
|
top_k: int
|
||||||
top_p: float
|
top_p: float
|
||||||
temperature: float
|
temperature: float
|
||||||
|
|
@ -146,9 +148,12 @@ class StreamGenerator(GeneratorCore):
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.max_len):
|
for _ in range(len(ids), self.config.max_len):
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_ids, request.temperature, request.top_k, request.top_p,
|
input_ids,
|
||||||
|
request.temperature,
|
||||||
|
request.top_k,
|
||||||
|
request.top_p,
|
||||||
kv_caches=kv_caches,
|
kv_caches=kv_caches,
|
||||||
start_pos=cur_cache_pos
|
start_pos=cur_cache_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_ids = next_token_id
|
input_ids = next_token_id
|
||||||
|
|
@ -172,7 +177,10 @@ class BatchGenerator(GeneratorCore):
|
||||||
if request.history is None:
|
if request.history is None:
|
||||||
request.history = [[] for _ in range(batch_size)]
|
request.history = [[] for _ in range(batch_size)]
|
||||||
|
|
||||||
prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)]
|
prompts = [
|
||||||
|
build_prompt(query, history)
|
||||||
|
for query, history in zip(request.query, request.history)
|
||||||
|
]
|
||||||
|
|
||||||
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
|
||||||
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id)
|
||||||
|
|
@ -189,13 +197,16 @@ class BatchGenerator(GeneratorCore):
|
||||||
|
|
||||||
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
|
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
attn_mask =cache_manager.get_seq_mask()
|
attn_mask = cache_manager.get_seq_mask()
|
||||||
|
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_tensor, request.temperature, request.top_k, request.top_p,
|
input_tensor,
|
||||||
|
request.temperature,
|
||||||
|
request.top_k,
|
||||||
|
request.top_p,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
kv_caches=kv_caches,
|
kv_caches=kv_caches,
|
||||||
start_pos=cur_cache_pos
|
start_pos=cur_cache_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
cur_cache_pos += cache_increase
|
cur_cache_pos += cache_increase
|
||||||
|
|
@ -248,7 +259,9 @@ class GeneratorFactory:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
def create_generator(
|
||||||
|
parameter: ModelParameter, request: GenerationRequest
|
||||||
|
) -> GeneratorCore:
|
||||||
"""Create a generator based on request characteristics.
|
"""Create a generator based on request characteristics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -282,7 +295,9 @@ class GeneratorFactory:
|
||||||
return EmbeddingEncoder(parameter)
|
return EmbeddingEncoder(parameter)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
def create(
|
||||||
|
cls, parameter: ModelParameter, request: GenerationRequest
|
||||||
|
) -> GeneratorCore:
|
||||||
"""Convenience method that delegates to create_generator.
|
"""Convenience method that delegates to create_generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -293,4 +308,3 @@ class GeneratorFactory:
|
||||||
Generator instance
|
Generator instance
|
||||||
"""
|
"""
|
||||||
return cls.create_generator(parameter, request)
|
return cls.create_generator(parameter, request)
|
||||||
|
|
||||||
|
|
@ -7,11 +7,4 @@ from khaosz.model.module import (
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
|
||||||
"Linear",
|
|
||||||
"RMSNorm",
|
|
||||||
"MLP",
|
|
||||||
"GQA",
|
|
||||||
"DecoderBlock",
|
|
||||||
"Transformer"
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -25,11 +25,12 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
.reshape(bs, slen, n_heads * n_rep, head_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_emb(
|
def get_rotary_emb(
|
||||||
dim: int,
|
dim: int,
|
||||||
max_len: int,
|
max_len: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Get the rotary embedding for the given dimension and maximum length.
|
Get the rotary embedding for the given dimension and maximum length.
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -46,6 +47,7 @@ def get_rotary_emb(
|
||||||
|
|
||||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Apply rotary embedding to the input tensor using cos/sin form.
|
Apply rotary embedding to the input tensor using cos/sin form.
|
||||||
|
|
@ -69,13 +71,13 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
|
||||||
x_imag_rot = x_real * sin + x_imag * cos
|
x_imag_rot = x_real * sin + x_imag * cos
|
||||||
|
|
||||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
||||||
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
||||||
|
|
||||||
return x_out.to(dtype)
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim: int, max_len: int, base: int=10000):
|
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
|
|
@ -89,7 +91,7 @@ class RotaryEmbedding(nn.Module):
|
||||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||||
self.max_len_cached = max_len
|
self.max_len_cached = max_len
|
||||||
|
|
||||||
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]:
|
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
||||||
seq_len = x.size(1)
|
seq_len = x.size(1)
|
||||||
|
|
||||||
if self.max_len_cached < seq_len + start_pos:
|
if self.max_len_cached < seq_len + start_pos:
|
||||||
|
|
@ -115,11 +117,11 @@ class RMSNorm(nn.Module):
|
||||||
def __init__(self, dim, norm_eps):
|
def __init__(self, dim, norm_eps):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
self.normalized_shape = (dim, )
|
self.normalized_shape = (dim,)
|
||||||
self.norm_eps = norm_eps
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
||||||
return rms.to(x.dtype)
|
return rms.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -136,7 +138,6 @@ class MLP(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GQA(nn.Module):
|
class GQA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -146,7 +147,7 @@ class GQA(nn.Module):
|
||||||
use_qk_norm: bool,
|
use_qk_norm: bool,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int
|
layer_id: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert dim % n_heads == 0
|
assert dim % n_heads == 0
|
||||||
|
|
@ -184,7 +185,7 @@ class GQA(nn.Module):
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = mask is None
|
||||||
|
|
@ -202,19 +203,24 @@ class GQA(nn.Module):
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
|
|
||||||
# copy to cache
|
# copy to cache
|
||||||
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
|
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||||
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
|
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||||
|
|
||||||
# get cache
|
# get cache
|
||||||
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
||||||
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
|
sdqa_out = (
|
||||||
|
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.contiguous()
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
|
||||||
if self.use_gated_attention:
|
if self.use_gated_attention:
|
||||||
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||||
|
|
@ -235,7 +241,7 @@ class MLA(nn.Module):
|
||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int
|
layer_id: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@ -270,7 +276,7 @@ class MLA(nn.Module):
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = mask is None
|
||||||
|
|
@ -285,12 +291,13 @@ class MLA(nn.Module):
|
||||||
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
|
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
|
||||||
|
|
||||||
k_nope, k_rope, v = torch.split(
|
k_nope, k_rope, v = torch.split(
|
||||||
kv,
|
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
|
||||||
[self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim],
|
|
||||||
dim=-1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_rope_head_dim:]
|
q_nope, q_rope = (
|
||||||
|
q[..., : self.qk_nope_head_dim],
|
||||||
|
q[..., self.qk_rope_head_dim :],
|
||||||
|
)
|
||||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||||
|
|
||||||
|
|
@ -299,10 +306,10 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
|
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||||
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
|
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||||
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
||||||
|
|
||||||
q = q.permute(0, 2, 1, 3)
|
q = q.permute(0, 2, 1, 3)
|
||||||
k = k.permute(0, 2, 1, 3)
|
k = k.permute(0, 2, 1, 3)
|
||||||
|
|
@ -329,11 +336,18 @@ class DecoderBlock(nn.Module):
|
||||||
norm_eps: int,
|
norm_eps: int,
|
||||||
use_qk_norm: bool,
|
use_qk_norm: bool,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int
|
layer_id: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attention = GQA(dim, n_heads, n_kv_heads,
|
self.attention = GQA(
|
||||||
use_qk_norm, norm_eps, use_gated_attention, layer_id)
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
use_qk_norm,
|
||||||
|
norm_eps,
|
||||||
|
use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
self.input_norm = RMSNorm(dim, norm_eps)
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
self.mlp = MLP(dim, dim_ffn)
|
self.mlp = MLP(dim, dim_ffn)
|
||||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
|
|
@ -344,15 +358,11 @@ class DecoderBlock(nn.Module):
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# attention
|
# attention
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.input_norm(x),
|
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
|
||||||
rotary_emb,
|
|
||||||
attention_mask,
|
|
||||||
kv_cache,
|
|
||||||
start_pos
|
|
||||||
)
|
)
|
||||||
x = attn_output + x
|
x = attn_output + x
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,21 @@ import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Mapping, Optional, Tuple
|
from typing import Any, Mapping, Optional, Tuple
|
||||||
from khaosz.config.model_config import ModelConfig
|
from khaosz.config.model_config import ModelConfig
|
||||||
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
|
from khaosz.model.module import (
|
||||||
|
Embedding,
|
||||||
|
DecoderBlock,
|
||||||
|
Linear,
|
||||||
|
RMSNorm,
|
||||||
|
RotaryEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
seq_mask: Tensor,
|
seq_mask: Tensor,
|
||||||
input_tensor: Tensor,
|
input_tensor: Tensor,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Create attention mask for GQA
|
Create attention mask for GQA
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -40,16 +46,20 @@ def process_attention_mask(
|
||||||
return seq_mask
|
return seq_mask
|
||||||
|
|
||||||
batch_size = seq_mask.size(0)
|
batch_size = seq_mask.size(0)
|
||||||
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||||
# (bsz, start_pos + seq_len)
|
# (bsz, start_pos + seq_len)
|
||||||
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
|
expanded_mask = seq_mask.unsqueeze(1).expand(
|
||||||
|
batch_size, seq_len, start_pos + seq_len
|
||||||
|
)
|
||||||
# (bsz, seq_len, start_pos + seq_len)
|
# (bsz, seq_len, start_pos + seq_len)
|
||||||
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||||
|
|
||||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||||
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
attention_mask = attention_mask.masked_fill_(
|
||||||
|
~expanded_mask, -torch.finfo(dtype).max / 2
|
||||||
|
).unsqueeze(1)
|
||||||
# (bsz, 1, seq_len, seq_len + start_pos)
|
# (bsz, 1, seq_len, seq_len + start_pos)
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
@ -59,14 +69,26 @@ class Transformer(nn.Module):
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
|
self.rotary_embeding = RotaryEmbedding(
|
||||||
|
config.dim // config.n_heads, config.max_len
|
||||||
|
)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList(
|
||||||
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
|
[
|
||||||
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
|
DecoderBlock(
|
||||||
for layer_id in range(config.n_layers)
|
config.dim,
|
||||||
])
|
config.n_heads,
|
||||||
|
config.dim_ffn,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.norm_eps,
|
||||||
|
config.use_qk_norm,
|
||||||
|
config.use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
@ -77,8 +99,8 @@ class Transformer(nn.Module):
|
||||||
self._init_parameters()
|
self._init_parameters()
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
lm_head_key = 'lm_head.weight'
|
lm_head_key = "lm_head.weight"
|
||||||
embed_key = 'embed_tokens.weight'
|
embed_key = "embed_tokens.weight"
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight == True:
|
||||||
# same tensor
|
# same tensor
|
||||||
|
|
@ -90,11 +112,13 @@ class Transformer(nn.Module):
|
||||||
|
|
||||||
return super().load_state_dict(state_dict, strict, assign)
|
return super().load_state_dict(state_dict, strict, assign)
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
state_dict = super().state_dict(
|
||||||
|
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||||
|
)
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight == True:
|
||||||
lm_head_key = prefix + 'lm_head.weight'
|
lm_head_key = prefix + "lm_head.weight"
|
||||||
if lm_head_key in state_dict:
|
if lm_head_key in state_dict:
|
||||||
del state_dict[lm_head_key]
|
del state_dict[lm_head_key]
|
||||||
|
|
||||||
|
|
@ -108,18 +132,16 @@ class Transformer(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
input_mask: Optional[Tensor]=None,
|
input_mask: Optional[Tensor] = None,
|
||||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None,
|
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
rotary_emb = self.rotary_embeding(x, start_pos)
|
rotary_emb = self.rotary_embeding(x, start_pos)
|
||||||
|
|
||||||
attn_mask = process_attention_mask(
|
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||||
input_mask, x, start_pos, is_causal=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
||||||
|
|
@ -127,8 +149,4 @@ class Transformer(nn.Module):
|
||||||
hidden_states = self.norm(x)
|
hidden_states = self.norm(x)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
return {
|
return {"logits": logits, "hidden_states": hidden_states}
|
||||||
"logits": logits,
|
|
||||||
"hidden_states": hidden_states
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,26 +2,20 @@ from khaosz.parallel.setup import (
|
||||||
get_world_size,
|
get_world_size,
|
||||||
get_rank,
|
get_rank,
|
||||||
get_current_device,
|
get_current_device,
|
||||||
|
|
||||||
only_on_rank,
|
only_on_rank,
|
||||||
setup_parallel,
|
setup_parallel,
|
||||||
spawn_parallel_fn
|
spawn_parallel_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.parallel.module import (
|
from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
|
||||||
RowParallelLinear,
|
|
||||||
ColumnParallelLinear
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_world_size",
|
"get_world_size",
|
||||||
"get_rank",
|
"get_rank",
|
||||||
"get_current_device",
|
"get_current_device",
|
||||||
|
|
||||||
"only_on_rank",
|
"only_on_rank",
|
||||||
"setup_parallel",
|
"setup_parallel",
|
||||||
"spawn_parallel_fn",
|
"spawn_parallel_fn",
|
||||||
|
|
||||||
"RowParallelLinear",
|
"RowParallelLinear",
|
||||||
"ColumnParallelLinear"
|
"ColumnParallelLinear",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ class RowParallelLinear(ParallelModel):
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
reduce_results: bool = True
|
reduce_results: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(process_group)
|
super().__init__(process_group)
|
||||||
|
|
||||||
|
|
@ -32,7 +32,9 @@ class RowParallelLinear(ParallelModel):
|
||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
|
|
||||||
if in_features % self.world_size != 0:
|
if in_features % self.world_size != 0:
|
||||||
raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}")
|
raise ValueError(
|
||||||
|
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
|
||||||
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
|
||||||
|
|
@ -49,8 +51,8 @@ class RowParallelLinear(ParallelModel):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||||
full_weight = state_dict.get('weight')
|
full_weight = state_dict.get("weight")
|
||||||
full_bias = state_dict.get('bias')
|
full_bias = state_dict.get("bias")
|
||||||
|
|
||||||
start_idx = self.rank * self.in_features_per_rank
|
start_idx = self.rank * self.in_features_per_rank
|
||||||
end_idx = start_idx + self.in_features_per_rank
|
end_idx = start_idx + self.in_features_per_rank
|
||||||
|
|
@ -68,7 +70,7 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
gather_results: bool = True
|
gather_results: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(process_group)
|
super().__init__(process_group)
|
||||||
|
|
||||||
|
|
@ -78,10 +80,16 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
self.gather_results = gather_results
|
self.gather_results = gather_results
|
||||||
|
|
||||||
if out_features % self.world_size != 0:
|
if out_features % self.world_size != 0:
|
||||||
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
|
raise ValueError(
|
||||||
|
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
|
||||||
|
)
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
|
self.weight = nn.Parameter(
|
||||||
self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
torch.empty(self.out_features_per_rank, self.in_features)
|
||||||
|
)
|
||||||
|
self.bias = (
|
||||||
|
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
output = F.linear(input, self.weight, self.bias)
|
output = F.linear(input, self.weight, self.bias)
|
||||||
|
|
@ -94,8 +102,8 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
def load_state_dict(self, state_dict: Dict[str, Tensor]):
|
||||||
full_weight = state_dict.get('weight')
|
full_weight = state_dict.get("weight")
|
||||||
full_bias = state_dict.get('bias')
|
full_bias = state_dict.get("bias")
|
||||||
|
|
||||||
start_idx = self.rank * self.out_features_per_rank
|
start_idx = self.rank * self.out_features_per_rank
|
||||||
end_idx = start_idx + self.out_features_per_rank
|
end_idx = start_idx + self.out_features_per_rank
|
||||||
|
|
|
||||||
|
|
@ -11,18 +11,21 @@ from typing import Callable, List, Optional
|
||||||
def get_current_device():
|
def get_current_device():
|
||||||
return os.environ["LOCAL_DEVICE"]
|
return os.environ["LOCAL_DEVICE"]
|
||||||
|
|
||||||
|
|
||||||
def get_world_size() -> int:
|
def get_world_size() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
return dist.get_world_size()
|
return dist.get_world_size()
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def get_rank() -> int:
|
def get_rank() -> int:
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
return dist.get_rank()
|
return dist.get_rank()
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def setup_parallel(
|
def setup_parallel(
|
||||||
rank: int,
|
rank: int,
|
||||||
|
|
@ -31,7 +34,7 @@ def setup_parallel(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None
|
device_ids: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -48,24 +51,21 @@ def setup_parallel(
|
||||||
rank = device_ids[rank % len(device_ids)]
|
rank = device_ids[rank % len(device_ids)]
|
||||||
device_id = torch.device(device_type, device_ids[rank])
|
device_id = torch.device(device_type, device_ids[rank])
|
||||||
|
|
||||||
os.environ['MASTER_ADDR'] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ['MASTER_PORT'] = master_port
|
os.environ["MASTER_PORT"] = master_port
|
||||||
|
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
rank=rank,
|
rank=rank, world_size=world_size, backend=backend, device_id=device_id
|
||||||
world_size=world_size,
|
|
||||||
backend=backend,
|
|
||||||
device_id=device_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if backend == "nccl" and torch.cuda.is_available():
|
if backend == "nccl" and torch.cuda.is_available():
|
||||||
torch.cuda.set_device(device_id)
|
torch.cuda.set_device(device_id)
|
||||||
elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available():
|
elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
torch.xpu.set_device(device_id)
|
torch.xpu.set_device(device_id)
|
||||||
|
|
||||||
yield dist.group.WORLD
|
yield dist.group.WORLD
|
||||||
|
|
@ -73,6 +73,7 @@ def setup_parallel(
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def only_on_rank(rank, sync=False):
|
def only_on_rank(rank, sync=False):
|
||||||
"""
|
"""
|
||||||
decorator to run a function only on a specific rank.
|
decorator to run a function only on a specific rank.
|
||||||
|
|
@ -94,6 +95,7 @@ def only_on_rank(rank, sync=False):
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def wrapper_spawn_func(
|
def wrapper_spawn_func(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
|
@ -103,7 +105,7 @@ def wrapper_spawn_func(
|
||||||
device_type: str,
|
device_type: str,
|
||||||
device_ids: List[int],
|
device_ids: List[int],
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
with setup_parallel(
|
with setup_parallel(
|
||||||
|
|
@ -113,7 +115,7 @@ def wrapper_spawn_func(
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
device_ids=device_ids
|
device_ids=device_ids,
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
|
|
@ -121,6 +123,7 @@ def wrapper_spawn_func(
|
||||||
print(f"Error in rank {rank}: {e}")
|
print(f"Error in rank {rank}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def spawn_parallel_fn(
|
def spawn_parallel_fn(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
|
@ -129,10 +132,17 @@ def spawn_parallel_fn(
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None,
|
device_ids: Optional[List[int]] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
# clear environment variables
|
# clear environment variables
|
||||||
for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']:
|
for key in [
|
||||||
|
"MASTER_ADDR",
|
||||||
|
"MASTER_PORT",
|
||||||
|
"RANK",
|
||||||
|
"WORLD_SIZE",
|
||||||
|
"LOCAL_RANK",
|
||||||
|
"LOCAL_DEVICE",
|
||||||
|
]:
|
||||||
if key in os.environ:
|
if key in os.environ:
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
|
|
@ -144,12 +154,17 @@ def spawn_parallel_fn(
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
return
|
return
|
||||||
|
|
||||||
wrapper_spawn_func_args = (world_size, backend, master_addr, master_port,
|
wrapper_spawn_func_args = (
|
||||||
device_type, device_ids, func, kwargs)
|
world_size,
|
||||||
|
backend,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
device_type,
|
||||||
|
device_ids,
|
||||||
|
func,
|
||||||
|
kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
mp.spawn(
|
mp.spawn(
|
||||||
wrapper_spawn_func,
|
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
|
||||||
nprocs=world_size,
|
|
||||||
args=wrapper_spawn_func_args,
|
|
||||||
join=True
|
|
||||||
)
|
)
|
||||||
|
|
@ -14,15 +14,12 @@ from khaosz.trainer.train_callback import (
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Main trainer
|
# Main trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
|
||||||
# Strategy factory
|
# Strategy factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"BaseStrategy",
|
"BaseStrategy",
|
||||||
|
|
||||||
# Scheduler factory
|
# Scheduler factory
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
"BaseScheduler",
|
"BaseScheduler",
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
"TrainCallback",
|
"TrainCallback",
|
||||||
"GradientClippingCallback",
|
"GradientClippingCallback",
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||||
""" Compute gradient norm for each parameter in the model. """
|
"""Compute gradient norm for each parameter in the model."""
|
||||||
norms = {}
|
norms = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
norms[name] = 0.0
|
norms[name] = 0.0
|
||||||
|
|
@ -11,8 +12,9 @@ def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||||
norms[name] = norm
|
norms[name] = norm
|
||||||
return norms
|
return norms
|
||||||
|
|
||||||
|
|
||||||
def grad_std(model: nn.Module) -> Dict[str, float]:
|
def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Compute standard deviation of gradients for each parameter. """
|
"""Compute standard deviation of gradients for each parameter."""
|
||||||
stds = {}
|
stds = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
stds[name] = 0.0
|
stds[name] = 0.0
|
||||||
|
|
@ -21,30 +23,33 @@ def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||||
stds[name] = std
|
stds[name] = std
|
||||||
return stds
|
return stds
|
||||||
|
|
||||||
|
|
||||||
def grad_max(model: nn.Module) -> Dict[str, float]:
|
def grad_max(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Find the maximum absolute gradient value for each parameter. """
|
"""Find the maximum absolute gradient value for each parameter."""
|
||||||
max_vals = {}
|
max_vals = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
max_vals[name] = -float('inf')
|
max_vals[name] = -float("inf")
|
||||||
if param.grad:
|
if param.grad:
|
||||||
max_val = param.grad.data.max().item()
|
max_val = param.grad.data.max().item()
|
||||||
max_vals[name] = max_val
|
max_vals[name] = max_val
|
||||||
|
|
||||||
return max_vals
|
return max_vals
|
||||||
|
|
||||||
|
|
||||||
def grad_min(model: nn.Module) -> Dict[str, float]:
|
def grad_min(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Find the minimum absolute gradient value for each parameter. """
|
"""Find the minimum absolute gradient value for each parameter."""
|
||||||
min_vals = {}
|
min_vals = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
min_vals[name] = float('inf')
|
min_vals[name] = float("inf")
|
||||||
if param.grad:
|
if param.grad:
|
||||||
min_val = param.grad.data.min().item()
|
min_val = param.grad.data.min().item()
|
||||||
min_vals[name] = min_val
|
min_vals[name] = min_val
|
||||||
|
|
||||||
return min_vals
|
return min_vals
|
||||||
|
|
||||||
|
|
||||||
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||||
""" Compute mean of gradients for each parameter. """
|
"""Compute mean of gradients for each parameter."""
|
||||||
means = {}
|
means = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
means[name] = 0.0
|
means[name] = 0.0
|
||||||
|
|
@ -54,8 +59,9 @@ def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||||
|
|
||||||
return means
|
return means
|
||||||
|
|
||||||
|
|
||||||
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||||
""" Count the number of NaNs in gradients for each parameter. """
|
"""Count the number of NaNs in gradients for each parameter."""
|
||||||
nan_nums = {}
|
nan_nums = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
nan_nums[name] = 0
|
nan_nums[name] = 0
|
||||||
|
|
@ -64,26 +70,34 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||||
nan_nums[name] = nan_num
|
nan_nums[name] = nan_num
|
||||||
return nan_nums
|
return nan_nums
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_loss(ctx):
|
def ctx_get_loss(ctx):
|
||||||
return ctx.loss
|
return ctx.loss
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_lr(ctx):
|
def ctx_get_lr(ctx):
|
||||||
return ctx.optimizer.param_groups[-1]['lr']
|
return ctx.optimizer.param_groups[-1]["lr"]
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_norm(ctx):
|
def ctx_get_grad_norm(ctx):
|
||||||
return grad_norm(ctx.model)
|
return grad_norm(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_std(ctx):
|
def ctx_get_grad_std(ctx):
|
||||||
return grad_std(ctx.model)
|
return grad_std(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_max(ctx):
|
def ctx_get_grad_max(ctx):
|
||||||
return grad_max(ctx.model)
|
return grad_max(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_min(ctx):
|
def ctx_get_grad_min(ctx):
|
||||||
return grad_min(ctx.model)
|
return grad_min(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_mean(ctx):
|
def ctx_get_grad_mean(ctx):
|
||||||
return grad_mean(ctx.model)
|
return grad_mean(ctx.model)
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_nan_num(ctx):
|
def ctx_get_grad_nan_num(ctx):
|
||||||
return grad_nan_num(ctx.model)
|
return grad_nan_num(ctx.model)
|
||||||
|
|
@ -55,11 +55,15 @@ class SchedulerFactory:
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function that registers the scheduler class
|
Decorator function that registers the scheduler class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
|
||||||
if not issubclass(scheduler_cls, BaseScheduler):
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
raise TypeError(
|
||||||
|
f"{scheduler_cls.__name__} must inherit from BaseScheduler"
|
||||||
|
)
|
||||||
cls.SCHEDULER_MAP[name] = scheduler_cls
|
cls.SCHEDULER_MAP[name] = scheduler_cls
|
||||||
return scheduler_cls
|
return scheduler_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -121,7 +125,7 @@ class CosineScheduler(BaseScheduler):
|
||||||
warmup_steps: int,
|
warmup_steps: int,
|
||||||
lr_decay_steps: int,
|
lr_decay_steps: int,
|
||||||
min_rate: float = 0.05,
|
min_rate: float = 0.05,
|
||||||
last_epoch: int = -1
|
last_epoch: int = -1,
|
||||||
):
|
):
|
||||||
self.warmup_steps = warmup_steps
|
self.warmup_steps = warmup_steps
|
||||||
self.lr_decay_steps = lr_decay_steps
|
self.lr_decay_steps = lr_decay_steps
|
||||||
|
|
@ -129,7 +133,6 @@ class CosineScheduler(BaseScheduler):
|
||||||
self.total_steps = warmup_steps + lr_decay_steps
|
self.total_steps = warmup_steps + lr_decay_steps
|
||||||
super().__init__(optimizer, last_epoch)
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
def get_lr(self) -> List[float]:
|
def get_lr(self) -> List[float]:
|
||||||
# warmup
|
# warmup
|
||||||
if self.last_epoch < self.warmup_steps:
|
if self.last_epoch < self.warmup_steps:
|
||||||
|
|
@ -145,19 +148,21 @@ class CosineScheduler(BaseScheduler):
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
state = super().state_dict()
|
state = super().state_dict()
|
||||||
state.update({
|
state.update(
|
||||||
'warmup_steps': self.warmup_steps,
|
{
|
||||||
'lr_decay_steps': self.lr_decay_steps,
|
"warmup_steps": self.warmup_steps,
|
||||||
'min_rate': self.min_rate,
|
"lr_decay_steps": self.lr_decay_steps,
|
||||||
'total_steps': self.total_steps,
|
"min_rate": self.min_rate,
|
||||||
})
|
"total_steps": self.total_steps,
|
||||||
|
}
|
||||||
|
)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
self.warmup_steps = state_dict.pop('warmup_steps')
|
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||||
self.lr_decay_steps = state_dict.pop('lr_decay_steps')
|
self.lr_decay_steps = state_dict.pop("lr_decay_steps")
|
||||||
self.min_rate = state_dict.pop('min_rate')
|
self.min_rate = state_dict.pop("min_rate")
|
||||||
self.total_steps = state_dict.pop('total_steps')
|
self.total_steps = state_dict.pop("total_steps")
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,7 +186,6 @@ class SGDRScheduler(BaseScheduler):
|
||||||
|
|
||||||
super().__init__(optimizer, last_epoch)
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
# warmup
|
# warmup
|
||||||
if self.last_epoch < self.warmup_steps:
|
if self.last_epoch < self.warmup_steps:
|
||||||
|
|
@ -204,7 +208,9 @@ class SGDRScheduler(BaseScheduler):
|
||||||
steps_in_cycle = steps_since_warmup - total_cycles_length
|
steps_in_cycle = steps_since_warmup - total_cycles_length
|
||||||
|
|
||||||
# 2. Cosine annealing within the current cycle
|
# 2. Cosine annealing within the current cycle
|
||||||
cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
|
cosine_factor = 0.5 * (
|
||||||
|
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
|
||||||
|
)
|
||||||
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
|
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
|
||||||
|
|
||||||
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
|
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
|
||||||
|
|
@ -212,18 +218,20 @@ class SGDRScheduler(BaseScheduler):
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
"""Returns the state of the scheduler as a dict."""
|
"""Returns the state of the scheduler as a dict."""
|
||||||
state = super().state_dict()
|
state = super().state_dict()
|
||||||
state.update({
|
state.update(
|
||||||
'warmup_steps': self.warmup_steps,
|
{
|
||||||
'cycle_length': self.cycle_length,
|
"warmup_steps": self.warmup_steps,
|
||||||
'min_rate': self.min_rate,
|
"cycle_length": self.cycle_length,
|
||||||
't_mult': self.t_mult
|
"min_rate": self.min_rate,
|
||||||
})
|
"t_mult": self.t_mult,
|
||||||
|
}
|
||||||
|
)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
"""Loads the scheduler's state."""
|
"""Loads the scheduler's state."""
|
||||||
self.warmup_steps = state_dict.pop('warmup_steps')
|
self.warmup_steps = state_dict.pop("warmup_steps")
|
||||||
self.cycle_length = state_dict.pop('cycle_length')
|
self.cycle_length = state_dict.pop("cycle_length")
|
||||||
self.min_rate = state_dict.pop('min_rate')
|
self.min_rate = state_dict.pop("min_rate")
|
||||||
self.t_mult = state_dict.pop('t_mult')
|
self.t_mult = state_dict.pop("t_mult")
|
||||||
super().load_state_dict(state_dict)
|
super().load_state_dict(state_dict)
|
||||||
|
|
@ -55,7 +55,9 @@ def get_logprobs(
|
||||||
"""
|
"""
|
||||||
allowed_reductions = ["mean", "sum", "none"]
|
allowed_reductions = ["mean", "sum", "none"]
|
||||||
if reduction not in allowed_reductions:
|
if reduction not in allowed_reductions:
|
||||||
raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'")
|
raise ValueError(
|
||||||
|
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
|
||||||
|
)
|
||||||
|
|
||||||
shifted_input_ids = input_ids[:, 1:]
|
shifted_input_ids = input_ids[:, 1:]
|
||||||
shifted_mask = mask[:, 1:]
|
shifted_mask = mask[:, 1:]
|
||||||
|
|
@ -64,13 +66,13 @@ def get_logprobs(
|
||||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||||
|
|
||||||
token_logprobs = torch.gather(
|
token_logprobs = torch.gather(
|
||||||
log_probs,
|
log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
|
||||||
dim=-1,
|
|
||||||
index=shifted_input_ids.unsqueeze(-1)
|
|
||||||
).squeeze(-1)
|
).squeeze(-1)
|
||||||
|
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0)
|
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
|
||||||
|
dim=-1
|
||||||
|
).clamp(min=1.0)
|
||||||
elif reduction == "sum":
|
elif reduction == "sum":
|
||||||
return (token_logprobs * shifted_mask).sum(dim=-1)
|
return (token_logprobs * shifted_mask).sum(dim=-1)
|
||||||
else:
|
else:
|
||||||
|
|
@ -80,7 +82,9 @@ def get_logprobs(
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
"""Abstract base class for training strategies."""
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
|
def __init__(
|
||||||
|
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
|
||||||
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
|
@ -128,11 +132,15 @@ class StrategyFactory:
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function that registers the strategy class
|
Decorator function that registers the strategy class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(strategy_cls: type) -> type:
|
def decorator(strategy_cls: type) -> type:
|
||||||
if not issubclass(strategy_cls, BaseStrategy):
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
raise TypeError(
|
||||||
|
f"{strategy_cls.__name__} must inherit from BaseStrategy"
|
||||||
|
)
|
||||||
cls.STRATEGY_MAP[name] = strategy_cls
|
cls.STRATEGY_MAP[name] = strategy_cls
|
||||||
return strategy_cls
|
return strategy_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -195,7 +203,7 @@ class SEQStrategy(BaseStrategy):
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten(),
|
target=target_ids.flatten(),
|
||||||
label_smoothing=self.label_smoothing
|
label_smoothing=self.label_smoothing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
@ -214,7 +222,11 @@ class SFTStrategy(BaseStrategy):
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
|
input_ids, target_ids, loss_mask = (
|
||||||
|
batch["input_ids"],
|
||||||
|
batch["target_ids"],
|
||||||
|
batch["loss_mask"],
|
||||||
|
)
|
||||||
|
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
logits = self.model(input_ids=input_ids)["logits"]
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
|
|
@ -224,7 +236,7 @@ class SFTStrategy(BaseStrategy):
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
target=target_ids.flatten(),
|
target=target_ids.flatten(),
|
||||||
ignore_index=ignore_index,
|
ignore_index=ignore_index,
|
||||||
label_smoothing=self.label_smoothing
|
label_smoothing=self.label_smoothing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
@ -239,12 +251,12 @@ class DPOStrategy(BaseStrategy):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
device: str,
|
device: str,
|
||||||
beta: float = 0.1,
|
beta: float = 0.1,
|
||||||
reduction: str = "mean",
|
reduction: str = "mean",
|
||||||
):
|
):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(model)
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
|
|
@ -261,12 +273,14 @@ class DPOStrategy(BaseStrategy):
|
||||||
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
|
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
log_ref = get_logprobs(self.ref_model, contact_ids, contact_mask, self.reduction)
|
log_ref = get_logprobs(
|
||||||
|
self.ref_model, contact_ids, contact_mask, self.reduction
|
||||||
|
)
|
||||||
|
|
||||||
log_pi_chosen = log_pi[:chosen_ids.shape[0]]
|
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
||||||
log_pi_rejected = log_pi[chosen_ids.shape[0]:]
|
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
|
||||||
log_ref_chosen = log_ref[:chosen_ids.shape[0]]
|
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
|
||||||
log_ref_rejected = log_ref[chosen_ids.shape[0]:]
|
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
|
||||||
|
|
||||||
pi_log_ratio = log_pi_chosen - log_pi_rejected
|
pi_log_ratio = log_pi_chosen - log_pi_rejected
|
||||||
ref_log_ratio = log_ref_chosen - log_ref_rejected
|
ref_log_ratio = log_ref_chosen - log_ref_rejected
|
||||||
|
|
@ -316,11 +330,15 @@ class GRPOStrategy(BaseStrategy):
|
||||||
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
||||||
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
||||||
|
|
||||||
log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction)
|
log_probs_policy = get_logprobs(
|
||||||
|
self.model, full_sequences, full_masks, self.reduction
|
||||||
|
)
|
||||||
log_probs_policy = log_probs_policy.view(batch_size, group_size)
|
log_probs_policy = log_probs_policy.view(batch_size, group_size)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
|
log_probs_ref = get_logprobs(
|
||||||
|
self.ref_model, full_sequences, full_masks, self.reduction
|
||||||
|
)
|
||||||
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||||
|
|
||||||
# Compute advantages from rewards with normalization
|
# Compute advantages from rewards with normalization
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from khaosz.trainer.metric_util import (
|
||||||
ctx_get_grad_norm,
|
ctx_get_grad_norm,
|
||||||
ctx_get_grad_mean,
|
ctx_get_grad_mean,
|
||||||
ctx_get_grad_std,
|
ctx_get_grad_std,
|
||||||
ctx_get_grad_nan_num
|
ctx_get_grad_nan_num,
|
||||||
)
|
)
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
from khaosz.trainer.train_context import TrainContext
|
from khaosz.trainer.train_context import TrainContext
|
||||||
|
|
@ -30,37 +30,38 @@ class TrainCallback(Protocol):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def on_train_begin(self, context: TrainContext):
|
def on_train_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of training. """
|
"""Called at the beginning of training."""
|
||||||
|
|
||||||
def on_train_end(self, context: TrainContext):
|
def on_train_end(self, context: TrainContext):
|
||||||
""" Called at the end of training. """
|
"""Called at the end of training."""
|
||||||
|
|
||||||
def on_epoch_begin(self, context: TrainContext):
|
def on_epoch_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of each epoch. """
|
"""Called at the beginning of each epoch."""
|
||||||
|
|
||||||
def on_epoch_end(self, context: TrainContext):
|
def on_epoch_end(self, context: TrainContext):
|
||||||
""" Called at the end of each epoch. """
|
"""Called at the end of each epoch."""
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
def on_step_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of each step. """
|
"""Called at the beginning of each step."""
|
||||||
|
|
||||||
def on_step_end(self, context: TrainContext):
|
def on_step_end(self, context: TrainContext):
|
||||||
""" Called at the end of each step."""
|
"""Called at the end of each step."""
|
||||||
|
|
||||||
def on_batch_begin(self, context: TrainContext):
|
def on_batch_begin(self, context: TrainContext):
|
||||||
""" Called at the beginning of each batch. """
|
"""Called at the beginning of each batch."""
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
""" Called at the end of each batch. """
|
"""Called at the end of each batch."""
|
||||||
|
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
""" Called when an error occurs during training. """
|
"""Called when an error occurs during training."""
|
||||||
|
|
||||||
|
|
||||||
class GradientClippingCallback(TrainCallback):
|
class GradientClippingCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Gradient clipping callback for trainer.
|
Gradient clipping callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
|
@ -73,6 +74,7 @@ class SchedulerCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Scheduler callback for trainer.
|
Scheduler callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -90,12 +92,13 @@ class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
|
|
@ -105,13 +108,17 @@ class CheckpointCallback(TrainCallback):
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
save_path = os.path.join(
|
||||||
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
|
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||||
|
)
|
||||||
|
state_dict = (
|
||||||
|
self.state_dict_fn(context.model)
|
||||||
|
if self.state_dict_fn
|
||||||
|
else context.model.state_dict()
|
||||||
|
)
|
||||||
|
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
|
||||||
epoch=context.epoch,
|
|
||||||
iteration=context.iteration
|
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
|
|
@ -133,6 +140,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_epoch: int):
|
def __init__(self, num_epoch: int):
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
self.progress_bar: tqdm = None
|
self.progress_bar: tqdm = None
|
||||||
|
|
@ -141,16 +149,18 @@ class ProgressBarCallback(TrainCallback):
|
||||||
def on_epoch_begin(self, context: TrainContext):
|
def on_epoch_begin(self, context: TrainContext):
|
||||||
self.progress_bar = tqdm(
|
self.progress_bar = tqdm(
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch+1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True
|
dynamic_ncols=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
self.progress_bar.set_postfix({
|
self.progress_bar.set_postfix(
|
||||||
"loss": f"{context.loss:.4f}",
|
{
|
||||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
|
"loss": f"{context.loss:.4f}",
|
||||||
})
|
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||||
|
}
|
||||||
|
)
|
||||||
self.progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -163,15 +173,15 @@ class ProgressBarCallback(TrainCallback):
|
||||||
class MetricLoggerCallback(TrainCallback):
|
class MetricLoggerCallback(TrainCallback):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
log_dir:str,
|
log_dir: str,
|
||||||
save_interval:int,
|
save_interval: int,
|
||||||
log_interval:int=10,
|
log_interval: int = 10,
|
||||||
metrics:List[str]=None
|
metrics: List[str] = None,
|
||||||
):
|
):
|
||||||
self.last_log_iter = 0
|
self.last_log_iter = 0
|
||||||
self.save_interval = save_interval
|
self.save_interval = save_interval
|
||||||
self.log_interval = log_interval
|
self.log_interval = log_interval
|
||||||
self.metrics = metrics or ['loss', 'lr']
|
self.metrics = metrics or ["loss", "lr"]
|
||||||
|
|
||||||
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
|
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
|
||||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -179,22 +189,22 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
self.log_cache = []
|
self.log_cache = []
|
||||||
|
|
||||||
self._metric_funcs = {
|
self._metric_funcs = {
|
||||||
'loss': ctx_get_loss,
|
"loss": ctx_get_loss,
|
||||||
'lr': ctx_get_lr,
|
"lr": ctx_get_lr,
|
||||||
'grad_norm': ctx_get_grad_norm,
|
"grad_norm": ctx_get_grad_norm,
|
||||||
'grad_std': ctx_get_grad_std,
|
"grad_std": ctx_get_grad_std,
|
||||||
'grad_max': ctx_get_grad_max,
|
"grad_max": ctx_get_grad_max,
|
||||||
'grad_min': ctx_get_grad_min,
|
"grad_min": ctx_get_grad_min,
|
||||||
'grad_mean': ctx_get_grad_mean,
|
"grad_mean": ctx_get_grad_mean,
|
||||||
'grad_nan_num': ctx_get_grad_nan_num
|
"grad_nan_num": ctx_get_grad_nan_num,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_log_data(self, context: TrainContext):
|
def _get_log_data(self, context: TrainContext):
|
||||||
return {
|
return {
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.iteration,
|
"iter": context.iteration,
|
||||||
**{m: self._metric_funcs[m](context) for m in self.metrics}
|
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||||
}
|
}
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -205,9 +215,9 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
def _save_log(self, epoch, iter):
|
def _save_log(self, epoch, iter):
|
||||||
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
|
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
|
||||||
|
|
||||||
with open(log_file, 'w') as f:
|
with open(log_file, "w") as f:
|
||||||
for log in self.log_cache:
|
for log in self.log_cache:
|
||||||
f.write(json.dumps(log) + '\n')
|
f.write(json.dumps(log) + "\n")
|
||||||
|
|
||||||
def on_batch_end(self, context):
|
def on_batch_end(self, context):
|
||||||
if context.iteration % self.log_interval == 0:
|
if context.iteration % self.log_interval == 0:
|
||||||
|
|
@ -224,4 +234,3 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
|
|
||||||
def on_error(self, context):
|
def on_error(self, context):
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
|
|
@ -72,7 +72,7 @@ class TrainContextBuilder:
|
||||||
data_source=config.dataset,
|
data_source=config.dataset,
|
||||||
start_epoch=self._context.epoch,
|
start_epoch=self._context.epoch,
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=config.random_seed
|
seed=config.random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
|
|
@ -81,7 +81,7 @@ class TrainContextBuilder:
|
||||||
sampler=resumeable_sampler,
|
sampler=resumeable_sampler,
|
||||||
num_workers=config.num_workers,
|
num_workers=config.num_workers,
|
||||||
pin_memory=config.pin_memory,
|
pin_memory=config.pin_memory,
|
||||||
prefetch_factor=config.prefetch_factor
|
prefetch_factor=config.prefetch_factor,
|
||||||
)
|
)
|
||||||
self._context.dataloader = dataloader
|
self._context.dataloader = dataloader
|
||||||
return self
|
return self
|
||||||
|
|
@ -91,7 +91,7 @@ class TrainContextBuilder:
|
||||||
model=self._context.model,
|
model=self._context.model,
|
||||||
train_type=self.config.strategy,
|
train_type=self.config.strategy,
|
||||||
device=get_current_device(),
|
device=get_current_device(),
|
||||||
**self.config.extra_kwargs
|
**self.config.extra_kwargs,
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from khaosz.trainer.train_callback import (
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
MetricLoggerCallback,
|
MetricLoggerCallback,
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
SchedulerCallback
|
SchedulerCallback,
|
||||||
)
|
)
|
||||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
|
|
@ -18,30 +18,32 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
|
||||||
train_config: TrainConfig,
|
|
||||||
callbacks: Optional[List[TrainCallback]] = None
|
|
||||||
):
|
):
|
||||||
self.train_config = train_config
|
self.train_config = train_config
|
||||||
default_callbacks = self._get_default_callbacks()
|
default_callbacks = self._get_default_callbacks()
|
||||||
self.callbacks = default_callbacks + callbacks if callbacks else default_callbacks
|
self.callbacks = (
|
||||||
|
default_callbacks + callbacks if callbacks else default_callbacks
|
||||||
|
)
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
train_config = self.train_config
|
train_config = self.train_config
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(train_config.n_epoch),
|
ProgressBarCallback(train_config.n_epoch),
|
||||||
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||||
MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
||||||
GradientClippingCallback(train_config.max_grad_norm),
|
GradientClippingCallback(train_config.max_grad_norm),
|
||||||
SchedulerCallback(),
|
SchedulerCallback(),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (TrainContextBuilder(self.train_config)
|
return (
|
||||||
.with_checkpoint(checkpoint)
|
TrainContextBuilder(self.train_config)
|
||||||
.with_dataloader()
|
.with_checkpoint(checkpoint)
|
||||||
.with_strategy()
|
.with_dataloader()
|
||||||
.build())
|
.with_strategy()
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
|
|
@ -59,23 +61,23 @@ class Trainer:
|
||||||
master_port=config.master_port,
|
master_port=config.master_port,
|
||||||
device_type=config.device_type,
|
device_type=config.device_type,
|
||||||
device_ids=config.device_ids,
|
device_ids=config.device_ids,
|
||||||
checkpoint=checkpoint
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
|
||||||
context = self._build_context(checkpoint)
|
context = self._build_context(checkpoint)
|
||||||
self._call_callbacks('on_train_begin', context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
# 1.epoch
|
# 1.epoch
|
||||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks('on_epoch_begin', context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
# 3. batch
|
# 3. batch
|
||||||
self._call_callbacks('on_batch_begin', context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
loss = context.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
context.iteration += 1
|
context.iteration += 1
|
||||||
|
|
@ -84,20 +86,20 @@ class Trainer:
|
||||||
stand_loss = loss / self.train_config.accumulation_steps
|
stand_loss = loss / self.train_config.accumulation_steps
|
||||||
stand_loss.backward()
|
stand_loss.backward()
|
||||||
|
|
||||||
self._call_callbacks('on_batch_end', context)
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||||
# 2. step
|
# 2. step
|
||||||
self._call_callbacks('on_step_begin', context)
|
self._call_callbacks("on_step_begin", context)
|
||||||
context.optimizer.step()
|
context.optimizer.step()
|
||||||
context.optimizer.zero_grad()
|
context.optimizer.zero_grad()
|
||||||
self._call_callbacks('on_step_end', context)
|
self._call_callbacks("on_step_end", context)
|
||||||
|
|
||||||
self._call_callbacks('on_epoch_end', context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||||
self._call_callbacks('on_error', context)
|
self._call_callbacks("on_error", context)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks('on_train_end', context)
|
self._call_callbacks("on_train_end", context)
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ classifiers = [
|
||||||
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
|
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = ["pytest==9.0.2"]
|
dev = ["pytest==9.0.2", "ruff"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["."]
|
where = ["."]
|
||||||
|
|
@ -36,3 +36,12 @@ extra-index-url = "https://download.pytorch.org/whl/cu126"
|
||||||
|
|
||||||
[tool.setuptools.dynamic]
|
[tool.setuptools.dynamic]
|
||||||
version = { attr = "khaosz.__version__" }
|
version = { attr = "khaosz.__version__" }
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py312"
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
indent-style = "space"
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
|
line-ending = "auto"
|
||||||
|
|
@ -24,7 +24,7 @@ class RandomDataset(Dataset):
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
||||||
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,))
|
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -65,7 +65,7 @@ class EarlyStoppingDataset(Dataset):
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.randint(0, 1000, (64,)),
|
"input_ids": torch.randint(0, 1000, (64,)),
|
||||||
"target_ids": torch.randint(0, 1000, (64,))
|
"target_ids": torch.randint(0, 1000, (64,)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -91,10 +91,10 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
"dim_ffn": dim_ffn,
|
"dim_ffn": dim_ffn,
|
||||||
"max_len": 1024,
|
"max_len": 1024,
|
||||||
"n_layers": 4,
|
"n_layers": 4,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
transformer_config = ModelConfig().load(config_path)
|
transformer_config = ModelConfig().load(config_path)
|
||||||
|
|
@ -112,16 +112,19 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
|
|
||||||
shutil.rmtree(test_dir)
|
shutil.rmtree(test_dir)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def random_dataset():
|
def random_dataset():
|
||||||
dataset = RandomDataset()
|
dataset = RandomDataset()
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def multi_turn_dataset():
|
def multi_turn_dataset():
|
||||||
dataset = MultiTurnDataset()
|
dataset = MultiTurnDataset()
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def early_stopping_dataset():
|
def early_stopping_dataset():
|
||||||
dataset = EarlyStoppingDataset()
|
dataset = EarlyStoppingDataset()
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
||||||
|
|
||||||
|
|
||||||
def test_single_process():
|
def test_single_process():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
|
|
@ -14,7 +15,6 @@ def test_single_process():
|
||||||
|
|
||||||
for epoch in range(3):
|
for epoch in range(3):
|
||||||
for iteration in range(10):
|
for iteration in range(10):
|
||||||
|
|
||||||
x = torch.randn(32, 10)
|
x = torch.randn(32, 10)
|
||||||
y = torch.randn(32, 5)
|
y = torch.randn(32, 5)
|
||||||
loss = model(x).mean()
|
loss = model(x).mean()
|
||||||
|
|
@ -24,11 +24,7 @@ def test_single_process():
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30)
|
||||||
state_dict=model.state_dict(),
|
|
||||||
epoch=3,
|
|
||||||
iteration=30
|
|
||||||
)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
checkpoint.save(tmpdir)
|
checkpoint.save(tmpdir)
|
||||||
|
|
@ -37,6 +33,8 @@ def test_single_process():
|
||||||
|
|
||||||
assert loaded_checkpoint.epoch == 3
|
assert loaded_checkpoint.epoch == 3
|
||||||
assert loaded_checkpoint.iteration == 30
|
assert loaded_checkpoint.iteration == 30
|
||||||
|
|
||||||
|
|
||||||
def simple_training():
|
def simple_training():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
|
|
@ -66,19 +64,14 @@ def simple_training():
|
||||||
else:
|
else:
|
||||||
shared_dir = None
|
shared_dir = None
|
||||||
|
|
||||||
|
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
dir_list = [shared_dir]
|
dir_list = [shared_dir]
|
||||||
dist.broadcast_object_list(dir_list, src=0)
|
dist.broadcast_object_list(dir_list, src=0)
|
||||||
shared_dir = dir_list[0]
|
shared_dir = dir_list[0]
|
||||||
|
|
||||||
|
|
||||||
loaded = Checkpoint.load(shared_dir)
|
loaded = Checkpoint.load(shared_dir)
|
||||||
assert loaded.epoch == 2
|
assert loaded.epoch == 2
|
||||||
|
|
||||||
|
|
||||||
def test_multi_process():
|
def test_multi_process():
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(simple_training, world_size=2, backend="gloo")
|
||||||
simple_training,
|
|
||||||
world_size=2,
|
|
||||||
backend="gloo"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from khaosz.data.serialization import save_h5
|
||||||
from khaosz.data.dataset import *
|
from khaosz.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
"""Test dataset loader with multiple random paths"""
|
"""Test dataset loader with multiple random paths"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
@ -16,7 +15,10 @@ def test_dataset_loader_random_paths(base_test_env):
|
||||||
for i in range(num_files):
|
for i in range(num_files):
|
||||||
seq_length = np.random.randint(200, 400)
|
seq_length = np.random.randint(200, 400)
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
|
"sequence": [
|
||||||
|
torch.randint(0, 1000, (seq_length,), dtype=torch.int64)
|
||||||
|
for _ in range(10)
|
||||||
|
],
|
||||||
}
|
}
|
||||||
save_h5(test_dir, f"data_{i}", dummy_data)
|
save_h5(test_dir, f"data_{i}", dummy_data)
|
||||||
|
|
||||||
|
|
@ -49,7 +51,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir, "dpo_data", dummy_data)
|
save_h5(test_dir, "dpo_data", dummy_data)
|
||||||
|
|
@ -62,7 +64,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dpo_dataset is not None
|
assert dpo_dataset is not None
|
||||||
assert hasattr(dpo_dataset, 'fetcher')
|
assert hasattr(dpo_dataset, "fetcher")
|
||||||
assert len(dpo_dataset) > 0
|
assert len(dpo_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get DPO items without errors
|
# Test that we can get DPO items without errors
|
||||||
|
|
@ -85,7 +87,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
|
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir, "sft_data", dummy_data)
|
save_h5(test_dir, "sft_data", dummy_data)
|
||||||
|
|
@ -98,7 +100,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert sft_dataset is not None
|
assert sft_dataset is not None
|
||||||
assert hasattr(sft_dataset, 'fetcher')
|
assert hasattr(sft_dataset, "fetcher")
|
||||||
assert len(sft_dataset) > 0
|
assert len(sft_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get SFT items without errors
|
# Test that we can get SFT items without errors
|
||||||
|
|
@ -121,15 +123,12 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir,"stride_test_data", dummy_data)
|
save_h5(test_dir, "stride_test_data", dummy_data)
|
||||||
|
|
||||||
# Test with custom stride
|
# Test with custom stride
|
||||||
custom_stride = 32
|
custom_stride = 32
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type="seq",
|
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
|
||||||
load_path=test_dir,
|
|
||||||
window_size=64,
|
|
||||||
stride=custom_stride
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dataset is not None
|
assert dataset is not None
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.data import *
|
from khaosz.data import *
|
||||||
|
|
||||||
|
|
||||||
def test_random_sampler_consistency(random_dataset):
|
def test_random_sampler_consistency(random_dataset):
|
||||||
"""Test RandomSampler produces consistent results with same seed"""
|
"""Test RandomSampler produces consistent results with same seed"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
@ -14,6 +15,7 @@ def test_random_sampler_consistency(random_dataset):
|
||||||
|
|
||||||
assert indices1 == indices2
|
assert indices1 == indices2
|
||||||
|
|
||||||
|
|
||||||
def test_random_sampler_different_seeds(random_dataset):
|
def test_random_sampler_different_seeds(random_dataset):
|
||||||
"""Test RandomSampler produces different results with different seeds"""
|
"""Test RandomSampler produces different results with different seeds"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from khaosz.data import *
|
||||||
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
||||||
from tokenizers import pre_tokenizers
|
from tokenizers import pre_tokenizers
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_env(request: pytest.FixtureRequest):
|
def test_env(request: pytest.FixtureRequest):
|
||||||
func_name = request.function.__name__
|
func_name = request.function.__name__
|
||||||
|
|
@ -28,9 +29,9 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
"dim_ffn": 256,
|
"dim_ffn": 256,
|
||||||
"max_len": 64,
|
"max_len": 64,
|
||||||
"n_layers": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5,
|
||||||
}
|
}
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
|
|
||||||
tokenizer = BpeTokenizer()
|
tokenizer = BpeTokenizer()
|
||||||
|
|
@ -51,30 +52,40 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
|
|
||||||
shutil.rmtree(test_dir)
|
shutil.rmtree(test_dir)
|
||||||
|
|
||||||
|
|
||||||
def test_model_parameter(test_env):
|
def test_model_parameter(test_env):
|
||||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||||
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
|
model_param = ModelParameter(
|
||||||
|
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||||
|
)
|
||||||
ModelParameter.save(model_param, save_dir)
|
ModelParameter.save(model_param, save_dir)
|
||||||
|
|
||||||
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
|
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
|
||||||
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
||||||
assert os.path.exists(os.path.join(save_dir, "config.json"))
|
assert os.path.exists(os.path.join(save_dir, "config.json"))
|
||||||
|
|
||||||
|
|
||||||
# transformer
|
# transformer
|
||||||
def test_transformer(test_env):
|
def test_transformer(test_env):
|
||||||
model = test_env["model"]
|
model = test_env["model"]
|
||||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
|
input_ids = torch.randint(
|
||||||
(4, test_env["transformer_config"].max_len))
|
0,
|
||||||
|
test_env["transformer_config"].vocab_size,
|
||||||
|
(4, test_env["transformer_config"].max_len),
|
||||||
|
)
|
||||||
output_logits = model(input_ids)["logits"]
|
output_logits = model(input_ids)["logits"]
|
||||||
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
|
target_shape = (
|
||||||
|
4,
|
||||||
|
test_env["transformer_config"].max_len,
|
||||||
|
test_env["transformer_config"].vocab_size,
|
||||||
|
)
|
||||||
assert output_logits.shape == target_shape
|
assert output_logits.shape == target_shape
|
||||||
|
|
||||||
|
|
||||||
# generator
|
# generator
|
||||||
def test_embedding_encoder_core(test_env):
|
def test_embedding_encoder_core(test_env):
|
||||||
parameter = ModelParameter(
|
parameter = ModelParameter(
|
||||||
test_env["model"],
|
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||||
test_env["tokenizer"],
|
|
||||||
test_env["transformer_config"]
|
|
||||||
)
|
)
|
||||||
encoder = EmbeddingEncoderCore(parameter)
|
encoder = EmbeddingEncoderCore(parameter)
|
||||||
|
|
||||||
|
|
@ -82,16 +93,14 @@ def test_embedding_encoder_core(test_env):
|
||||||
assert isinstance(single_emb, torch.Tensor)
|
assert isinstance(single_emb, torch.Tensor)
|
||||||
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
||||||
|
|
||||||
|
|
||||||
batch_emb = encoder.encode(["测试1", "测试2"])
|
batch_emb = encoder.encode(["测试1", "测试2"])
|
||||||
assert isinstance(batch_emb, list)
|
assert isinstance(batch_emb, list)
|
||||||
assert len(batch_emb) == 2
|
assert len(batch_emb) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_generator_core(test_env):
|
def test_generator_core(test_env):
|
||||||
parameter = ModelParameter(
|
parameter = ModelParameter(
|
||||||
test_env["model"],
|
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
||||||
test_env["tokenizer"],
|
|
||||||
test_env["transformer_config"]
|
|
||||||
)
|
)
|
||||||
generator = GeneratorCore(parameter)
|
generator = GeneratorCore(parameter)
|
||||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
||||||
|
|
@ -102,7 +111,7 @@ def test_generator_core(test_env):
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
attn_mask=None,
|
attn_mask=None,
|
||||||
kv_caches=None,
|
kv_caches=None,
|
||||||
start_pos=0
|
start_pos=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert next_token_id.shape == (4, 1)
|
assert next_token_id.shape == (4, 1)
|
||||||
|
|
|
||||||
|
|
@ -22,17 +22,13 @@ def transformer_test_env():
|
||||||
"dim_ffn": 256,
|
"dim_ffn": 256,
|
||||||
"max_len": 64,
|
"max_len": 64,
|
||||||
"n_layers": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
|
|
||||||
yield {
|
yield {"test_dir": test_dir, "config_path": config_path, "config": config}
|
||||||
"test_dir": test_dir,
|
|
||||||
"config_path": config_path,
|
|
||||||
"config": config
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.path.exists(test_dir):
|
if os.path.exists(test_dir):
|
||||||
try:
|
try:
|
||||||
|
|
@ -50,7 +46,7 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
# case 1: tie weight
|
# case 1: tie weight
|
||||||
config_data["tie_weight"] = True
|
config_data["tie_weight"] = True
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
|
|
@ -68,7 +64,7 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
# case 2: not tie weight
|
# case 2: not tie weight
|
||||||
config_data["tie_weight"] = False
|
config_data["tie_weight"] = False
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
|
|
@ -83,6 +79,7 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert not torch.equal(model.lm_head.weight, original_weight)
|
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||||
|
|
||||||
|
|
||||||
def test_model_save_load_with_tie_weight(transformer_test_env):
|
def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
test_dir = transformer_test_env["test_dir"]
|
test_dir = transformer_test_env["test_dir"]
|
||||||
model_path = os.path.join(test_dir, "model.safetensors")
|
model_path = os.path.join(test_dir, "model.safetensors")
|
||||||
|
|
@ -93,7 +90,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
config_data["tie_weight"] = True
|
config_data["tie_weight"] = True
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
config = ModelConfig().load(config_path)
|
||||||
|
|
@ -111,7 +108,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
|
|
||||||
# case 2: not tie weight (form tie-weight state dict load)
|
# case 2: not tie weight (form tie-weight state dict load)
|
||||||
config_data["tie_weight"] = False
|
config_data["tie_weight"] = False
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
loaded_config = ModelConfig().load(config_path)
|
loaded_config = ModelConfig().load(config_path)
|
||||||
|
|
@ -121,4 +118,3 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||||
assert "lm_head.weight" in model.state_dict()
|
assert "lm_head.weight" in model.state_dict()
|
||||||
|
|
||||||
|
|
@ -1,16 +1,14 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from khaosz.parallel import (
|
from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
|
||||||
get_rank,
|
|
||||||
only_on_rank,
|
|
||||||
spawn_parallel_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _test_only_on_rank_helper():
|
def _test_only_on_rank_helper():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def only_on_rank():
|
def only_on_rank():
|
||||||
result = _test_only_on_rank_helper()
|
result = _test_only_on_rank_helper()
|
||||||
if get_rank() == 0:
|
if get_rank() == 0:
|
||||||
|
|
@ -18,22 +16,17 @@ def only_on_rank():
|
||||||
else:
|
else:
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
def all_reduce():
|
def all_reduce():
|
||||||
x = torch.tensor([get_rank()], dtype=torch.int)
|
x = torch.tensor([get_rank()], dtype=torch.int)
|
||||||
dist.all_reduce(x, op=dist.ReduceOp.SUM)
|
dist.all_reduce(x, op=dist.ReduceOp.SUM)
|
||||||
expected_sum = sum(range(dist.get_world_size()))
|
expected_sum = sum(range(dist.get_world_size()))
|
||||||
assert x.item() == expected_sum
|
assert x.item() == expected_sum
|
||||||
|
|
||||||
|
|
||||||
def test_spawn_only_on_rank():
|
def test_spawn_only_on_rank():
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo")
|
||||||
only_on_rank,
|
|
||||||
world_size=2,
|
|
||||||
backend="gloo"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_spawn_all_reduce():
|
def test_spawn_all_reduce():
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(all_reduce, world_size=2, backend="gloo")
|
||||||
all_reduce,
|
|
||||||
world_size=2,
|
|
||||||
backend="gloo"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -3,57 +3,48 @@ import torch
|
||||||
from khaosz.config import *
|
from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
strategy='seq',
|
strategy="seq",
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=3,
|
ckpt_interval=3,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Create custom callbacks to track calls
|
# Create custom callbacks to track calls
|
||||||
callback_calls = []
|
callback_calls = []
|
||||||
|
|
||||||
class TrackingCallback(TrainCallback):
|
class TrackingCallback(TrainCallback):
|
||||||
def on_train_begin(self, context):
|
def on_train_begin(self, context):
|
||||||
callback_calls.append('on_train_begin')
|
callback_calls.append("on_train_begin")
|
||||||
|
|
||||||
def on_batch_end(self, context):
|
def on_batch_end(self, context):
|
||||||
callback_calls.append('on_batch_end')
|
callback_calls.append("on_batch_end")
|
||||||
|
|
||||||
def on_epoch_end(self, context):
|
def on_epoch_end(self, context):
|
||||||
callback_calls.append('on_epoch_end')
|
callback_calls.append("on_epoch_end")
|
||||||
|
|
||||||
|
trainer = Trainer(train_config, callbacks=[TrackingCallback()])
|
||||||
|
|
||||||
trainer = Trainer(
|
|
||||||
train_config,
|
|
||||||
callbacks=[TrackingCallback()]
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
# Verify callbacks were called
|
# Verify callbacks were called
|
||||||
assert 'on_train_begin' in callback_calls
|
assert "on_train_begin" in callback_calls
|
||||||
assert 'on_batch_end' in callback_calls
|
assert "on_batch_end" in callback_calls
|
||||||
assert 'on_epoch_end' in callback_calls
|
assert "on_epoch_end" in callback_calls
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.data.serialization import Checkpoint
|
from khaosz.data.serialization import Checkpoint
|
||||||
|
|
||||||
|
|
||||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
"""Simulate early stopping behavior"""
|
"""Simulate early stopping behavior"""
|
||||||
|
|
||||||
|
|
@ -19,13 +20,13 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=1,
|
ckpt_interval=1,
|
||||||
accumulation_steps=2,
|
accumulation_steps=2,
|
||||||
random_seed=np.random.randint(1e4),
|
random_seed=np.random.randint(1e4),
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
|
||||||
|
|
@ -20,14 +20,14 @@ def test_schedule_factory_random_configs():
|
||||||
CosineScheduleConfig(
|
CosineScheduleConfig(
|
||||||
warmup_steps=np.random.randint(50, 200),
|
warmup_steps=np.random.randint(50, 200),
|
||||||
total_steps=np.random.randint(1000, 5000),
|
total_steps=np.random.randint(1000, 5000),
|
||||||
min_rate=np.random.uniform(0.01, 0.1)
|
min_rate=np.random.uniform(0.01, 0.1),
|
||||||
),
|
),
|
||||||
SGDRScheduleConfig(
|
SGDRScheduleConfig(
|
||||||
warmup_steps=np.random.randint(50, 200),
|
warmup_steps=np.random.randint(50, 200),
|
||||||
cycle_length=np.random.randint(500, 2000),
|
cycle_length=np.random.randint(500, 2000),
|
||||||
t_mult=np.random.randint(1, 3),
|
t_mult=np.random.randint(1, 3),
|
||||||
min_rate=np.random.uniform(0.01, 0.1)
|
min_rate=np.random.uniform(0.01, 0.1),
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in schedule_configs:
|
for config in schedule_configs:
|
||||||
|
|
@ -41,7 +41,9 @@ def test_schedule_factory_random_configs():
|
||||||
if isinstance(config, CosineScheduleConfig):
|
if isinstance(config, CosineScheduleConfig):
|
||||||
assert isinstance(scheduler, CosineScheduler)
|
assert isinstance(scheduler, CosineScheduler)
|
||||||
assert scheduler.warmup_steps == config.warmup_steps
|
assert scheduler.warmup_steps == config.warmup_steps
|
||||||
assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
assert (
|
||||||
|
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
||||||
|
)
|
||||||
assert scheduler.min_rate == config.min_rate
|
assert scheduler.min_rate == config.min_rate
|
||||||
elif isinstance(config, SGDRScheduleConfig):
|
elif isinstance(config, SGDRScheduleConfig):
|
||||||
assert isinstance(scheduler, SGDRScheduler)
|
assert isinstance(scheduler, SGDRScheduler)
|
||||||
|
|
@ -52,8 +54,8 @@ def test_schedule_factory_random_configs():
|
||||||
|
|
||||||
# Test scheduler state dict functionality
|
# Test scheduler state dict functionality
|
||||||
state_dict = scheduler.state_dict()
|
state_dict = scheduler.state_dict()
|
||||||
assert 'warmup_steps' in state_dict
|
assert "warmup_steps" in state_dict
|
||||||
assert 'min_rate' in state_dict
|
assert "min_rate" in state_dict
|
||||||
|
|
||||||
# Test scheduler step functionality
|
# Test scheduler step functionality
|
||||||
initial_lr = scheduler.get_last_lr()
|
initial_lr = scheduler.get_last_lr()
|
||||||
|
|
|
||||||
|
|
@ -6,15 +6,13 @@ from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.data.dataset import *
|
from khaosz.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
def test_different_batch_sizes(base_test_env, random_dataset):
|
def test_different_batch_sizes(base_test_env, random_dataset):
|
||||||
"""Test training with different batch sizes"""
|
"""Test training with different batch sizes"""
|
||||||
batch_sizes = [1, 2, 4, 8]
|
batch_sizes = [1, 2, 4, 8]
|
||||||
|
|
||||||
for batch_size in batch_sizes:
|
for batch_size in batch_sizes:
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
|
|
@ -24,27 +22,25 @@ def test_different_batch_sizes(base_test_env, random_dataset):
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
checkpoint_interval=5,
|
ckpt_interval=5,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=np.random.randint(1000),
|
random_seed=np.random.randint(1000),
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.batch_size == batch_size
|
assert train_config.batch_size == batch_size
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_accumulation(base_test_env, random_dataset):
|
def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
"""Test training with different gradient accumulation steps"""
|
"""Test training with different gradient accumulation steps"""
|
||||||
accumulation_steps_list = [1, 2, 4]
|
accumulation_steps_list = [1, 2, 4]
|
||||||
|
|
||||||
for accumulation_steps in accumulation_steps_list:
|
for accumulation_steps in accumulation_steps_list:
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
|
|
@ -54,14 +50,14 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=10,
|
ckpt_interval=10,
|
||||||
accumulation_steps=accumulation_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
@ -69,20 +65,18 @@ def test_gradient_accumulation(base_test_env, random_dataset):
|
||||||
|
|
||||||
assert train_config.accumulation_steps == accumulation_steps
|
assert train_config.accumulation_steps == accumulation_steps
|
||||||
|
|
||||||
|
|
||||||
def test_memory_efficient_training(base_test_env, random_dataset):
|
def test_memory_efficient_training(base_test_env, random_dataset):
|
||||||
"""Test training with memory-efficient configurations"""
|
"""Test training with memory-efficient configurations"""
|
||||||
# Test with smaller batch sizes and gradient checkpointing
|
# Test with smaller batch sizes and gradient checkpointing
|
||||||
small_batch_configs = [
|
small_batch_configs = [
|
||||||
{"batch_size": 1, "accumulation_steps": 8},
|
{"batch_size": 1, "accumulation_steps": 8},
|
||||||
{"batch_size": 2, "accumulation_steps": 4},
|
{"batch_size": 2, "accumulation_steps": 4},
|
||||||
{"batch_size": 4, "accumulation_steps": 2}
|
{"batch_size": 4, "accumulation_steps": 2},
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
|
||||||
warmup_steps=10,
|
|
||||||
total_steps=20
|
|
||||||
)
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config)
|
||||||
|
|
||||||
|
|
@ -92,14 +86,14 @@ def test_memory_efficient_training(base_test_env, random_dataset):
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=config["batch_size"],
|
batch_size=config["batch_size"],
|
||||||
checkpoint_interval=5,
|
ckpt_interval=5,
|
||||||
accumulation_steps=config["accumulation_steps"],
|
accumulation_steps=config["accumulation_steps"],
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
device_type=base_test_env["device"]
|
device_type=base_test_env["device"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert train_config.accumulation_steps == config["accumulation_steps"]
|
assert train_config.accumulation_steps == config["accumulation_steps"]
|
||||||
|
|
@ -17,7 +17,7 @@ class GenerationBenchmark:
|
||||||
self,
|
self,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.float16
|
dtype: torch.dtype = torch.float16,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
@ -28,7 +28,13 @@ class GenerationBenchmark:
|
||||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||||
"""初始化KV缓存"""
|
"""初始化KV缓存"""
|
||||||
config = self.config
|
config = self.config
|
||||||
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
|
shape = (
|
||||||
|
batch_size,
|
||||||
|
config.max_len,
|
||||||
|
config.n_layers,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.dim // config.n_heads,
|
||||||
|
)
|
||||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
return (k_cache, v_cache)
|
return (k_cache, v_cache)
|
||||||
|
|
@ -39,7 +45,7 @@ class GenerationBenchmark:
|
||||||
high=self.config.vocab_size,
|
high=self.config.vocab_size,
|
||||||
size=(batch_size, prompt_length),
|
size=(batch_size, prompt_length),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_ids = torch.randint(
|
gen_ids = torch.randint(
|
||||||
|
|
@ -47,7 +53,7 @@ class GenerationBenchmark:
|
||||||
high=self.config.vocab_size,
|
high=self.config.vocab_size,
|
||||||
size=(batch_size, total_length - prompt_length),
|
size=(batch_size, total_length - prompt_length),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt_ids, gen_ids
|
return prompt_ids, gen_ids
|
||||||
|
|
@ -61,7 +67,9 @@ class GenerationBenchmark:
|
||||||
) -> BenchmarkResult:
|
) -> BenchmarkResult:
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
prompt_ids, _ = self._prepare_inputs(
|
||||||
|
batch_size, prompt_length, prompt_length
|
||||||
|
)
|
||||||
_ = self.model(prompt_ids)
|
_ = self.model(prompt_ids)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
@ -70,7 +78,9 @@ class GenerationBenchmark:
|
||||||
total_tokens = batch_size * prompt_length * num_trials
|
total_tokens = batch_size * prompt_length * num_trials
|
||||||
|
|
||||||
for trial in range(num_trials):
|
for trial in range(num_trials):
|
||||||
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
|
prompt_ids, _ = self._prepare_inputs(
|
||||||
|
batch_size, prompt_length, prompt_length
|
||||||
|
)
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
|
@ -82,8 +92,10 @@ class GenerationBenchmark:
|
||||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||||
total_time += trial_time
|
total_time += trial_time
|
||||||
|
|
||||||
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
print(
|
||||||
f"({prompt_length / trial_time:.1f} tokens/s)")
|
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||||
|
f"({prompt_length / trial_time:.1f} tokens/s)"
|
||||||
|
)
|
||||||
|
|
||||||
return BenchmarkResult(
|
return BenchmarkResult(
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
|
|
@ -95,7 +107,7 @@ class GenerationBenchmark:
|
||||||
"prompt_length": prompt_length,
|
"prompt_length": prompt_length,
|
||||||
"dtype": self.dtype,
|
"dtype": self.dtype,
|
||||||
"device": self.device,
|
"device": self.device,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|
@ -111,8 +123,9 @@ class GenerationBenchmark:
|
||||||
total_tokens = batch_size * gen_length * num_trials
|
total_tokens = batch_size * gen_length * num_trials
|
||||||
|
|
||||||
for trial in range(num_trials):
|
for trial in range(num_trials):
|
||||||
|
prompt_ids, gen_ids = self._prepare_inputs(
|
||||||
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
batch_size, prompt_length, prompt_length + gen_length
|
||||||
|
)
|
||||||
kv_cache = self._initialize_kv_cache(batch_size)
|
kv_cache = self._initialize_kv_cache(batch_size)
|
||||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||||
|
|
||||||
|
|
@ -125,8 +138,10 @@ class GenerationBenchmark:
|
||||||
|
|
||||||
current_pos = prompt_length
|
current_pos = prompt_length
|
||||||
for i in range(gen_length):
|
for i in range(gen_length):
|
||||||
input_token = gen_ids[:, i:i+1]
|
input_token = gen_ids[:, i : i + 1]
|
||||||
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
|
_ = self.model(
|
||||||
|
input_token, persistent_key_values=kv_cache, start_pos=current_pos
|
||||||
|
)
|
||||||
current_pos += 1
|
current_pos += 1
|
||||||
|
|
||||||
end_event.record()
|
end_event.record()
|
||||||
|
|
@ -135,9 +150,10 @@ class GenerationBenchmark:
|
||||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||||
total_time += trial_time
|
total_time += trial_time
|
||||||
|
|
||||||
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
print(
|
||||||
f"({gen_length / trial_time:.1f} tokens/s)")
|
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||||
|
f"({gen_length / trial_time:.1f} tokens/s)"
|
||||||
|
)
|
||||||
|
|
||||||
return BenchmarkResult(
|
return BenchmarkResult(
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
|
|
@ -150,7 +166,7 @@ class GenerationBenchmark:
|
||||||
"gen_length": gen_length,
|
"gen_length": gen_length,
|
||||||
"dtype": self.dtype,
|
"dtype": self.dtype,
|
||||||
"device": self.device,
|
"device": self.device,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -164,9 +180,13 @@ def print_benchmark_result(result: BenchmarkResult):
|
||||||
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
||||||
|
|
||||||
if benchmark_type == "prefill":
|
if benchmark_type == "prefill":
|
||||||
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
|
print(
|
||||||
|
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
|
||||||
|
)
|
||||||
elif benchmark_type == "decoding":
|
elif benchmark_type == "decoding":
|
||||||
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
|
print(
|
||||||
|
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
@ -190,9 +210,12 @@ if __name__ == "__main__":
|
||||||
print("Running Transformer Generation Benchmark")
|
print("Running Transformer Generation Benchmark")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5)
|
prefill_result = benchmark.run_prefill_benchmark(
|
||||||
|
batch_size=4, prompt_length=512, num_trials=5
|
||||||
|
)
|
||||||
print_benchmark_result(prefill_result)
|
print_benchmark_result(prefill_result)
|
||||||
|
|
||||||
gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5)
|
gen_result = benchmark.run_decoding_benchmark(
|
||||||
|
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
|
||||||
|
)
|
||||||
print_benchmark_result(gen_result)
|
print_benchmark_result(gen_result)
|
||||||
|
|
||||||
|
|
@ -21,10 +21,10 @@ def processor(
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
generator = BatchGenerator(param)
|
generator = BatchGenerator(param)
|
||||||
|
|
||||||
with open(input_json_file, "r", encoding='utf-8') as f:
|
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
queries = [item[question_key] for item in input_data]
|
queries = [item[question_key] for item in input_data]
|
||||||
|
|
@ -41,24 +41,60 @@ def processor(
|
||||||
|
|
||||||
responses = generator.generate(request)
|
responses = generator.generate(request)
|
||||||
|
|
||||||
with open(output_json_file, "w", encoding='utf-8') as f:
|
with open(output_json_file, "w", encoding="utf-8") as f:
|
||||||
for query, response in zip(queries, responses):
|
for query, response in zip(queries, responses):
|
||||||
output_item = {question_key: query, response_key: response}
|
output_item = {question_key: query, response_key: response}
|
||||||
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
|
f.write(json.dumps(output_item, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||||
|
|
||||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
parser.add_argument(
|
||||||
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
|
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||||
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
|
)
|
||||||
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
|
parser.add_argument(
|
||||||
parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.")
|
"--input_json_file",
|
||||||
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
|
type=str,
|
||||||
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
|
required=True,
|
||||||
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
|
help="Path to the input JSONL file.",
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_json_file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the output JSONL file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--question_key",
|
||||||
|
type=str,
|
||||||
|
default="question",
|
||||||
|
help="Key for the question in the input JSON.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--response_key",
|
||||||
|
type=str,
|
||||||
|
default="response",
|
||||||
|
help="Key for the response in the output JSON.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.60,
|
||||||
|
help="Temperature for generating responses.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_k", type=int, default=30, help="Top-k value for generating responses."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_p",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="Top-p value for generating responses.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Batch size for generating responses."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,59 +11,56 @@ from khaosz.inference.core import disable_random_init
|
||||||
|
|
||||||
|
|
||||||
def compute_perplexity(
|
def compute_perplexity(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
input_mask: Tensor,
|
input_mask: Tensor,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Compute the perplexity of a batch of input sequences,
|
Compute the perplexity of a batch of input sequences,
|
||||||
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output = model(input_ids, input_mask)
|
output = model(input_ids, input_mask)
|
||||||
logits = output["logits"]
|
logits = output["logits"]
|
||||||
|
|
||||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||||
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
shifted_input_ids = input_ids[:, 1:] # [batch_size, seq_len-1]
|
||||||
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
shifted_mask = input_mask[:, 1:] # [batch_size, seq_len-1]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
shifted_logits.flatten(0, 1),
|
shifted_logits.flatten(0, 1), shifted_input_ids.flatten(0, 1), reduction="none"
|
||||||
shifted_input_ids.flatten(0, 1),
|
|
||||||
reduction='none'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
loss = loss.view(shifted_input_ids.shape) # [batch_size, seq_len-1]
|
||||||
loss = loss * shifted_mask
|
loss = loss * shifted_mask
|
||||||
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
|
sentence_loss = (loss).sum(dim=1) / shifted_mask.sum(dim=1)
|
||||||
perplexity = torch.exp(sentence_loss) # [batch_size]
|
perplexity = torch.exp(sentence_loss) # [batch_size]
|
||||||
|
|
||||||
return perplexity
|
return perplexity
|
||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
model_dir: str,
|
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||||
input_file: str,
|
|
||||||
output_file: str,
|
|
||||||
batch_size: int,
|
|
||||||
text_key: str
|
|
||||||
):
|
):
|
||||||
with disable_random_init():
|
with disable_random_init():
|
||||||
param = ModelParameter.load(model_dir)
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
param.to(device='cuda', dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
model = param.model
|
model = param.model
|
||||||
tokenizer = param.tokenizer
|
tokenizer = param.tokenizer
|
||||||
|
|
||||||
with open(input_file, "r", encoding='utf-8') as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
texts = [item[text_key] for item in input_data]
|
texts = [item[text_key] for item in input_data]
|
||||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||||
output_data = []
|
output_data = []
|
||||||
|
|
||||||
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
|
for i in tqdm(
|
||||||
batch_encoded = encoded_texts[i:i + batch_size]
|
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
|
||||||
batch_texts = texts[i:i + batch_size]
|
):
|
||||||
|
batch_encoded = encoded_texts[i : i + batch_size]
|
||||||
|
batch_texts = texts[i : i + batch_size]
|
||||||
max_len = max(len(seq) for seq in batch_encoded)
|
max_len = max(len(seq) for seq in batch_encoded)
|
||||||
padded_ids = []
|
padded_ids = []
|
||||||
masks = []
|
masks = []
|
||||||
|
|
@ -82,18 +79,31 @@ def process_file(
|
||||||
for text, ppl in zip(batch_texts, perplexity):
|
for text, ppl in zip(batch_texts, perplexity):
|
||||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||||
|
|
||||||
with open(output_file, "w", encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
for item in output_data:
|
for item in output_data:
|
||||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
parser.add_argument(
|
||||||
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
|
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||||
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
|
)
|
||||||
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
|
parser.add_argument(
|
||||||
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
|
"--input_file", type=str, required=True, help="Path to the input file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_file", type=str, required=True, help="Path to the output file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=4, help="Batch size for evaluation."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_key",
|
||||||
|
type=str,
|
||||||
|
default="text",
|
||||||
|
help="Key for the text field in the input data.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
|
|
||||||
173
tools/train.py
173
tools/train.py
|
|
@ -16,39 +16,129 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||||
|
|
||||||
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
|
parser.add_argument(
|
||||||
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
"--train_type",
|
||||||
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
type=str,
|
||||||
|
required=True,
|
||||||
|
choices=["seq", "sft", "dpo"],
|
||||||
|
help="Train type.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_root_path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the root directory of the dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--param_path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the model parameters or resume checkpoint.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
|
parser.add_argument(
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
|
)
|
||||||
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
|
parser.add_argument(
|
||||||
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
|
"--batch_size", type=int, default=1, help="Batch size for training."
|
||||||
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
|
)
|
||||||
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
|
parser.add_argument(
|
||||||
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
|
"--accumulation_steps",
|
||||||
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
type=int,
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
default=1,
|
||||||
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
|
help="Number of iterations between each optimizer step.",
|
||||||
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
|
)
|
||||||
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
parser.add_argument(
|
||||||
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
"--warmup_steps",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of iters between warnings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_grad_norm",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Max gradient norm for clipping.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adamw_beta1",
|
||||||
|
type=float,
|
||||||
|
default=0.9,
|
||||||
|
help="Beta values for AdamW optimizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adamw_beta2",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="Beta values for AdamW optimizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adamw_weight_decay",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="Weight decay for AdamW optimizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random_seed", type=int, default=3407, help="Random seed for reproducibility."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_workers", type=int, default=4, help="Number of workers for data loading."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_pin_memory",
|
||||||
|
action="store_false",
|
||||||
|
dest="pin_memory",
|
||||||
|
help="Disable pin memory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--window_size",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="the max length of the input sequence.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stride", type=int, default=None, help="the step size of the input sequence."
|
||||||
|
)
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter")
|
parser.add_argument(
|
||||||
|
"--label_smoothing",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="cross_entropy function label smoothing parameter",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
parser.add_argument(
|
||||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
"--ckpt_interval",
|
||||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
type=int,
|
||||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
default=5000,
|
||||||
|
help="Number of iters between checkpoints.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt_dir",
|
||||||
|
type=str,
|
||||||
|
default="checkpoint",
|
||||||
|
help="Directory to save checkpoints.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--start_epoch", type=int, default=0, help="Start epoch for training."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
parser.add_argument(
|
||||||
|
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
local_rank = get_rank()
|
local_rank = get_rank()
|
||||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
||||||
|
|
@ -56,16 +146,21 @@ def ddp_wrap(model: nn.Module):
|
||||||
model,
|
model,
|
||||||
device_ids=[local_rank],
|
device_ids=[local_rank],
|
||||||
output_device=local_rank,
|
output_device=local_rank,
|
||||||
find_unused_parameters=False
|
find_unused_parameters=False,
|
||||||
)
|
)
|
||||||
return ddp_model
|
return ddp_model
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), **kwargs)
|
return optim.AdamW(model.parameters(), **kwargs)
|
||||||
|
|
||||||
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
|
|
||||||
|
def create_scheduler(
|
||||||
|
optimizer: optim.Optimizer, **kwargs
|
||||||
|
) -> optim.lr_scheduler.LRScheduler:
|
||||||
return SchedulerFactory.load(optimizer, **kwargs)
|
return SchedulerFactory.load(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||||
return model.module.state_dict()
|
return model.module.state_dict()
|
||||||
|
|
||||||
|
|
@ -81,8 +176,8 @@ def train(
|
||||||
start_batch: int,
|
start_batch: int,
|
||||||
accumulation_steps: int,
|
accumulation_steps: int,
|
||||||
warmup_steps: int,
|
warmup_steps: int,
|
||||||
checkpoint_interval: int,
|
ckpt_interval: int,
|
||||||
checkpoint_dir: str,
|
ckpt_dir: str,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
adamw_beta1: float,
|
adamw_beta1: float,
|
||||||
adamw_beta2: float,
|
adamw_beta2: float,
|
||||||
|
|
@ -107,16 +202,13 @@ def train(
|
||||||
|
|
||||||
model = parameter.model
|
model = parameter.model
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
||||||
"dpo_beta": dpo_beta,
|
|
||||||
"label_smoothing": label_smoothing
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
stride=stride
|
stride=stride,
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
|
|
@ -124,10 +216,15 @@ def train(
|
||||||
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
total_steps=len(dataset) * n_epoch // (batch_size * nprocs),
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer_fn = partial(create_optimizer,
|
optimizer_fn = partial(
|
||||||
**{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay})
|
create_optimizer,
|
||||||
scheduler_fn = partial(create_scheduler,
|
**{
|
||||||
**{"schedule_config": schedule_config})
|
"lr": max_lr,
|
||||||
|
"betas": (adamw_beta1, adamw_beta2),
|
||||||
|
"weight_decay": adamw_weight_decay,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config})
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -135,12 +232,12 @@ def train(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
checkpoint_dir=checkpoint_dir,
|
ckpt_dir=ckpt_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
start_epoch=start_epoch,
|
start_epoch=start_epoch,
|
||||||
start_batch=start_batch,
|
start_batch=start_batch,
|
||||||
checkpoint_interval=checkpoint_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
accumulation_steps=accumulation_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue