diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..6bcea63 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,27 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[dev] + + - name: Check formatting with ruff + run: | + ruff format --check . \ No newline at end of file diff --git a/.github/workflows/spell-check.yml b/.github/workflows/spell-check.yml deleted file mode 100644 index d359cc2..0000000 --- a/.github/workflows/spell-check.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: Spell Check -on: [push, pull_request] - -permissions: - contents: read - -jobs: - spellcheck: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Check spelling in specific files - uses: codespell-project/actions-codespell@v2 - with: - check_filenames: true - only_warn: false - path: "**/*.{md, py}" \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..8ba0596 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,31 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[dev] + + - name: Run tests with pytest + run: | + python -m pytest tests/ -v \ No newline at end of file diff --git a/.gitignore b/.gitignore index c261f0c..eae218e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,8 @@ !*.py !*.md !*.png + !LICENSE -!pyproject.toml \ No newline at end of file +!pyproject.toml +!.github/workflows/lint.yml +!.github/workflows/tests.yml \ No newline at end of file diff --git a/README.md b/README.md index eeac26c..e519abe 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,8 @@ python train.py \ --n_epoch=5 \ --batch_size=8 \ --max_lr=2e-4 \ ---checkpoint_interval=10000 \ ---checkpoint_dir=checkpoints +--ckpt_interval=10000 \ +--ckpt_dir=checkpoints ``` **Parameter Explanation:** @@ -67,8 +67,8 @@ python train.py \ - `--accumulation_steps`: Number of batches per training step - `--warmup_steps`: Warmup steps - `--max_lr`: Maximum learning rate (using warmup + cosine decay) -- `--checkpoint_interval`: Checkpoint saving interval -- `--checkpoint_dir`: Checkpoint saving directory +- `--ckpt_interval`: Checkpoint saving interval +- `--ckpt_dir`: Checkpoint saving directory - `--resume_dir`: Resume training from specified path @@ -191,8 +191,8 @@ python train.py \ --n_epoch=5 \ --batch_size=8 \ --max_lr=2e-4 \ ---checkpoint_interval=10000 \ ---checkpoint_dir=checkpoints +--ckpt_interval=10000 \ +--ckpt_dir=checkpoints ``` **参数说明:** @@ -204,8 +204,8 @@ python train.py \ - `--accumulation_steps`: 每个训练步骤的 batch 数量 - `--warmup_steps`: 预热步数(warmup steps) - `--max_lr`: 最大学习率(使用预热 + 余弦衰减) -- `--checkpoint_interval`: 检查点保存间隔 -- `--checkpoint_dir`: 检查点保存目录 +- `--ckpt_interval`: 检查点保存间隔 +- `--ckpt_dir`: 检查点保存目录 - `--resume_dir`: 从指定路径恢复训练 diff --git a/demo/download.py b/demo/download.py index ca641f2..670e9aa 100644 --- a/demo/download.py +++ b/demo/download.py @@ -2,13 +2,12 @@ import os from huggingface_hub import snapshot_download -PROJECT_ROOT = os.path.dirname( - os.path.dirname(os.path.abspath(__file__))) +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if __name__ == "__main__": snapshot_download( repo_id="ViperEk/KHAOSZ", - local_dir=os.path.join(PROJECT_ROOT, "params"), - force_download=True - ) \ No newline at end of file + local_dir=os.path.join(PROJECT_ROOT, "params"), + force_download=True, + ) diff --git a/demo/generate_ar.py b/demo/generate_ar.py index b8c0a26..1a6d078 100644 --- a/demo/generate_ar.py +++ b/demo/generate_ar.py @@ -5,18 +5,18 @@ from khaosz.inference.core import disable_random_init from khaosz.inference.generator import LoopGenerator, GenerationRequest -PROJECT_ROOT = os.path.dirname( - os.path.dirname(os.path.abspath(__file__))) +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + def generate_text(): - + with disable_random_init(): model_dir = os.path.join(PROJECT_ROOT, "params") param = ModelParameter.load(model_dir) - - param.to(device='cuda', dtype=torch.bfloat16) + + param.to(device="cuda", dtype=torch.bfloat16) query = input(">> ") - + request = GenerationRequest( query=query, temperature=0.8, @@ -28,8 +28,9 @@ def generate_text(): ) generator = LoopGenerator(param) response = generator.generate(request) - + print(response) + if __name__ == "__main__": - generate_text() \ No newline at end of file + generate_text() diff --git a/demo/generate_batch.py b/demo/generate_batch.py index 3b39b04..234f5f4 100644 --- a/demo/generate_batch.py +++ b/demo/generate_batch.py @@ -4,18 +4,24 @@ from khaosz.config.param_config import ModelParameter from khaosz.inference.core import disable_random_init from khaosz.inference.generator import BatchGenerator, GenerationRequest -PROJECT_ROOT = os.path.dirname( - os.path.dirname(os.path.abspath(__file__))) +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + def batch_generate(): with disable_random_init(): model_dir = os.path.join(PROJECT_ROOT, "params") param = ModelParameter.load(model_dir) - param.to(device='cuda', dtype=torch.bfloat16) + param.to(device="cuda", dtype=torch.bfloat16) generator = BatchGenerator(param) - inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"] - + inputs = [ + "你好", + "请问什么是人工智能", + "今天天气如何", + "我感到焦虑, 请问我应该怎么办", + "请问什么是显卡", + ] + request = GenerationRequest( query=inputs, temperature=0.8, @@ -26,9 +32,10 @@ def batch_generate(): system_prompt=None, ) responses = generator.generate(request) - + for q, r in zip(inputs, responses): print((q, r)) + if __name__ == "__main__": - batch_generate() \ No newline at end of file + batch_generate() diff --git a/demo/stream_chat.py b/demo/stream_chat.py index abddd88..c2bb322 100644 --- a/demo/stream_chat.py +++ b/demo/stream_chat.py @@ -5,16 +5,16 @@ from khaosz.inference.core import disable_random_init from khaosz.inference.generator import StreamGenerator, GenerationRequest -PROJECT_ROOT = os.path.dirname( - os.path.dirname(os.path.abspath(__file__))) +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + def chat(): - + with disable_random_init(): model_dir = os.path.join(PROJECT_ROOT, "params") param = ModelParameter.load(model_dir) - param.to(device='cuda', dtype=torch.bfloat16) + param.to(device="cuda", dtype=torch.bfloat16) generator = StreamGenerator(param) history = [] @@ -22,7 +22,7 @@ def chat(): query = input(">> ") if query == "!exit": break - + request = GenerationRequest( query=query, temperature=0.8, @@ -32,7 +32,7 @@ def chat(): history=history, system_prompt=None, ) - + response_size = 0 full_response = "" for response in generator.generate(request): @@ -40,10 +40,10 @@ def chat(): print(response[response_size:], end="", flush=True) response_size = len(response) full_response = response - + # After generation, update history history.append((query, full_response.strip())) if __name__ == "__main__": - chat() \ No newline at end of file + chat() diff --git a/khaosz/__init__.py b/khaosz/__init__.py index 5c02a8e..cb88bb8 100644 --- a/khaosz/__init__.py +++ b/khaosz/__init__.py @@ -6,41 +6,30 @@ from khaosz.config import ( TrainConfig, ) from khaosz.model.transformer import Transformer -from khaosz.data import ( - DatasetLoader, - BpeTokenizer -) +from khaosz.data import DatasetLoader, BpeTokenizer from khaosz.inference.generator import ( GenerationRequest, LoopGenerator, StreamGenerator, BatchGenerator, EmbeddingEncoder, - GeneratorFactory -) -from khaosz.trainer import ( - Trainer, - StrategyFactory, - SchedulerFactory + GeneratorFactory, ) +from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory __all__ = [ "Transformer", - "ModelConfig", "TrainConfig", - "DatasetLoader", "BpeTokenizer", - "GenerationRequest", "LoopGenerator", "StreamGenerator", "BatchGenerator", "EmbeddingEncoder", "GeneratorFactory", - "Trainer", "StrategyFactory", - "SchedulerFactory" -] \ No newline at end of file + "SchedulerFactory", +] diff --git a/khaosz/config/__init__.py b/khaosz/config/__init__.py index c153017..eb2a81b 100644 --- a/khaosz/config/__init__.py +++ b/khaosz/config/__init__.py @@ -1,10 +1,10 @@ from khaosz.config.model_config import ModelConfig from khaosz.config.param_config import BaseModelIO, ModelParameter from khaosz.config.schedule_config import ( - ScheduleConfig, - CosineScheduleConfig, + ScheduleConfig, + CosineScheduleConfig, SGDRScheduleConfig, - ScheduleConfigFactory + ScheduleConfigFactory, ) from khaosz.config.train_config import TrainConfig @@ -13,14 +13,12 @@ __all__ = [ # Base I/O "BaseModelIO", "ModelParameter", - # Model configuration "ModelConfig", "TrainConfig", - # Schedule configuration "ScheduleConfig", "CosineScheduleConfig", "SGDRScheduleConfig", "ScheduleConfigFactory", -] \ No newline at end of file +] diff --git a/khaosz/config/model_config.py b/khaosz/config/model_config.py index 62947ec..06e28e9 100644 --- a/khaosz/config/model_config.py +++ b/khaosz/config/model_config.py @@ -14,30 +14,29 @@ class ModelConfig: norm_eps: Optional[float] = None dim_ffn: Optional[int] = None tie_weight: Optional[bool] = None - + # RoPE max_len: Optional[int] = None rope_theta: Optional[float] = None - + # GQA n_heads: Optional[int] = None n_kv_heads: Optional[int] = None use_qk_norm: Optional[bool] = None use_gated_attention: Optional[bool] = None - - + def load(self, config_path: str) -> Self: config = {} - with open(config_path, 'r') as f: - config.update(json.load(f)) - + with open(config_path, "r") as f: + config.update(json.load(f)) + for key, value in config.items(): if hasattr(self, key): setattr(self, key, value) - + return self - + def save(self, config_path: str): config_dict = {k: v for k, v in asdict(self).items() if v is not None} - with open(config_path, 'w') as f: + with open(config_path, "w") as f: json.dump(config_dict, f, indent=4) diff --git a/khaosz/config/param_config.py b/khaosz/config/param_config.py index 6036a08..2f61098 100644 --- a/khaosz/config/param_config.py +++ b/khaosz/config/param_config.py @@ -9,58 +9,57 @@ from khaosz.data.tokenizer import BpeTokenizer from khaosz.config.model_config import ModelConfig from khaosz.model.transformer import Transformer + @dataclass class BaseModelIO: """Base class for model I/O operations.""" - + model: Optional[nn.Module] = field( - default=None, - metadata={"help": "Transformer model."} + default=None, metadata={"help": "Transformer model."} ) tokenizer: BpeTokenizer = field( - default_factory=BpeTokenizer, - metadata={"help": "Tokenizer for the model."} + default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."} ) config: ModelConfig = field( default_factory=ModelConfig, - metadata={"help": "Transformer model configuration."} + metadata={"help": "Transformer model configuration."}, ) - + def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]: """Get standardized file paths for model components.""" dir_path = Path(directory) return { "model": dir_path / "model.safetensors", - "config": dir_path / "config.json", - "tokenizer": dir_path / "tokenizer.json" + "config": dir_path / "config.json", + "tokenizer": dir_path / "tokenizer.json", } - + def save_components(self, save_dir: Union[str, Path]): """Save core model components.""" paths = self._get_file_paths(save_dir) paths["model"].parent.mkdir(parents=True, exist_ok=True) - + if self.model is not None: st.save_file(self.model.state_dict(), str(paths["model"])) self.config.save(str(paths["config"])) self.tokenizer.save(str(paths["tokenizer"])) - + def load_components(self, load_dir: Union[str, Path]) -> Self: """Load core model components.""" paths = self._get_file_paths(load_dir) - + self.config.load(str(paths["config"])) self.tokenizer.load(str(paths["tokenizer"])) - + if self.model is None: self.model = Transformer(self.config) - + if paths["model"].exists(): state_dict = st.load_file(str(paths["model"])) self.model.load_state_dict(state_dict) - + return self - + def to(self, *args, **kwargs) -> "BaseModelIO": """Move model to device.""" if self.model is not None: @@ -71,13 +70,12 @@ class BaseModelIO: @dataclass class ModelParameter(BaseModelIO): """Container for model parameters with serialization capabilities.""" - + @classmethod def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]): instance.save_components(save_dir) - + @classmethod def load(cls, load_dir: Union[str, Path]) -> "ModelParameter": instance = cls() return instance.load_components(load_dir) - diff --git a/khaosz/config/schedule_config.py b/khaosz/config/schedule_config.py index 5a53a33..10fbf45 100644 --- a/khaosz/config/schedule_config.py +++ b/khaosz/config/schedule_config.py @@ -6,35 +6,35 @@ from dataclasses import dataclass, field @dataclass class ScheduleConfig(ABC): """Base configuration class for learning rate schedulers. - + Provides common validation and interface for all schedule types. """ - + schedule_type: str = field( default="cosine", metadata={ - "help": "Type of learning rate schedule.", - "choices": ["cosine", "sgdr"] - } + "help": "Type of learning rate schedule.", + "choices": ["cosine", "sgdr"], + }, ) warmup_steps: int = field( - default=1000, - metadata={"help": "Number of warmup steps."} + default=1000, metadata={"help": "Number of warmup steps."} ) min_rate: float = field( - default=0.05, - metadata={"help": "Minimum learning rate multiplier."} + default=0.05, metadata={"help": "Minimum learning rate multiplier."} ) - + @abstractmethod def get_kwargs(self) -> Dict[str, Any]: """Get configuration kwargs for scheduler creation.""" raise NotImplementedError - + def validate(self) -> None: """Validate configuration parameters.""" if self.warmup_steps < 0: - raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}") + raise ValueError( + f"warmup_steps must be non-negative, got {self.warmup_steps}" + ) if not 0 <= self.min_rate <= 1: raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}") @@ -42,44 +42,43 @@ class ScheduleConfig(ABC): @dataclass class CosineScheduleConfig(ScheduleConfig): """Cosine annealing learning rate schedule configuration.""" - + total_steps: int = field( - default=None, - metadata={"help": "Total training steps for cosine schedule."} + default=None, metadata={"help": "Total training steps for cosine schedule."} ) - + def __post_init__(self) -> None: self.schedule_type = "cosine" self.validate() - + def get_kwargs(self) -> Dict[str, Any]: if self.total_steps is None: raise ValueError("total_steps must be specified for cosine schedule") - + return { "schedule_type": self.schedule_type, "warmup_steps": self.warmup_steps, "lr_decay_steps": self.total_steps - self.warmup_steps, - "min_rate": self.min_rate + "min_rate": self.min_rate, } - + def validate(self) -> None: super().validate() if self.total_steps is not None and self.total_steps <= self.warmup_steps: - raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})") + raise ValueError( + f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})" + ) @dataclass class SGDRScheduleConfig(ScheduleConfig): """Stochastic Gradient Descent with Warm Restarts schedule configuration.""" - + cycle_length: int = field( - default=1000, - metadata={"help": "Length of the first cycle in steps."} + default=1000, metadata={"help": "Length of the first cycle in steps."} ) - t_mult: int = field( - default=2, - metadata={"help": "Multiplier for cycle length growth."} + t_mult: int = field( + default=2, metadata={"help": "Multiplier for cycle length growth."} ) def __post_init__(self) -> None: @@ -92,9 +91,9 @@ class SGDRScheduleConfig(ScheduleConfig): "warmup_steps": self.warmup_steps, "cycle_length": self.cycle_length, "min_rate": self.min_rate, - "t_mult": self.t_mult + "t_mult": self.t_mult, } - + def validate(self) -> None: super().validate() if self.cycle_length <= 0: @@ -105,33 +104,33 @@ class SGDRScheduleConfig(ScheduleConfig): class ScheduleConfigFactory: """Factory class for creating ScheduleConfig instances. - + Supports both direct instantiation and factory creation methods. - + Example usage: # Direct creation config = CosineScheduleConfig(total_steps=10000) - + # Factory method config = ScheduleConfigFactory.create("cosine", total_steps=10000) """ - + CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = { "cosine": CosineScheduleConfig, "sgdr": SGDRScheduleConfig, } - + @classmethod def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig: """Create a schedule config instance. - + Args: schedule_type: Type of schedule ("cosine", "sgdr") **kwargs: Arguments passed to the config constructor - + Returns: ScheduleConfig instance - + Raises: ValueError: If schedule_type is not supported """ @@ -140,11 +139,11 @@ class ScheduleConfigFactory: f"Unknown schedule type: '{schedule_type}'. " f"Supported types: {sorted(cls.CONFIG_MAP.keys())}" ) - + config_cls = cls.CONFIG_MAP[schedule_type] return config_cls(**kwargs) - + @classmethod def available_types(cls) -> list: """Return list of available schedule type names.""" - return list(cls.CONFIG_MAP.keys()) \ No newline at end of file + return list(cls.CONFIG_MAP.keys()) diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py index b5e2903..f4b3eb6 100644 --- a/khaosz/config/train_config.py +++ b/khaosz/config/train_config.py @@ -10,127 +10,92 @@ from typing import Callable, List, Optional @dataclass class TrainConfig: # basic setting - model: nn.Module = field( - default=None, - metadata={"help": "Model for training."} - ) - strategy: str = field( - default=None, - metadata={"help": "Training strategy."} - ) - dataset: Dataset = field( - default=None, - metadata={"help": "Dataset for training."} - ) + model: nn.Module = field(default=None, metadata={"help": "Model for training."}) + strategy: str = field(default=None, metadata={"help": "Training strategy."}) + dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."}) optimizer_fn: Callable[[nn.Module], Optimizer] = field( - default=None, - metadata={"help": "Optimizer factory for training."} + default=None, metadata={"help": "Optimizer factory for training."} ) scheduler_fn: Callable[[Optimizer], LRScheduler] = field( - default=None, - metadata={"help": "Scheduler factory for training."} - ) - n_epoch: int = field( - default=1, - metadata={"help": "Number of epochs for training."} - ) - batch_size: int = field( - default=4, - metadata={"help": "Batch size for training."} + default=None, metadata={"help": "Scheduler factory for training."} ) + n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) + batch_size: int = field(default=4, metadata={"help": "Batch size for training."}) accumulation_steps: int = field( - default=1, - metadata={"help": "Number of iterations between steps."} + default=1, metadata={"help": "Number of iterations between steps."} ) max_grad_norm: float = field( - default=1.0, - metadata={"help": "Maximum gradient norm."} + default=1.0, metadata={"help": "Maximum gradient norm."} ) - + # checkpoint setting - start_epoch: int = field( - default=0, - metadata={"help": "Start epoch for training."} - ) + start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."}) start_batch: int = field( - default=0, - metadata={"help": "Start batch iteration for training."} + default=0, metadata={"help": "Start batch iteration for training."} ) - checkpoint_dir: str = field( - default="./checkpoint", - metadata={"help": "Checkpoint directory."} + ckpt_dir: str = field( + default="./checkpoint", metadata={"help": "Checkpoint directory."} ) - checkpoint_interval: int = field( - default=5000, - metadata={"help": "Number of iterations between checkpoints."} + ckpt_interval: int = field( + default=5000, metadata={"help": "Number of iterations between checkpoints."} ) - + # dataloader setting - random_seed: int = field( - default=3407, - metadata={"help": "Random seed."} - ) + random_seed: int = field(default=3407, metadata={"help": "Random seed."}) num_workers: int = field( - default=0, - metadata={"help": "Number of workers for dataloader."} + default=0, metadata={"help": "Number of workers for dataloader."} ) prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "Prefetch factor for dataloader."} + default=None, metadata={"help": "Prefetch factor for dataloader."} ) pin_memory: bool = field( - default=False, - metadata={"help": "Pin memory for dataloader."} + default=False, metadata={"help": "Pin memory for dataloader."} ) - + # distributed training nprocs: int = field( - default=1, - metadata={"help": "Number of processes for distributed training."} + default=1, metadata={"help": "Number of processes for distributed training."} ) backend: str = field( - default="nccl", - metadata={"help": "Distributed training backend."} + default="nccl", metadata={"help": "Distributed training backend."} ) master_addr: str = field( default="localhost", - metadata={"help": "Master address for distributed training."} + metadata={"help": "Master address for distributed training."}, ) master_port: str = field( - default="29500", - metadata={"help": "Master port for distributed training."} + default="29500", metadata={"help": "Master port for distributed training."} ) parallel_wrapper: Optional[Callable] = field( - default=None, - metadata={"help": "Parallel function for training."} + default=None, metadata={"help": "Parallel function for training."} ) state_dict_fn: Optional[Callable] = field( - default=None, - metadata={"help": "Parallel function for state dict saving."} + default=None, metadata={"help": "Parallel function for state dict saving."} ) # others device_ids: Optional[List[int]] = field( - default=None, - metadata={"help": "Device ids for distributed training."} + default=None, metadata={"help": "Device ids for distributed training."} ) device_type: str = field( - default="cuda", - metadata={"help": "Device type for distributed training."} + default="cuda", metadata={"help": "Device type for distributed training."} ) extra_kwargs: dict = field( - default_factory=dict, - metadata={"help": "Other arguments."} + default_factory=dict, metadata={"help": "Other arguments."} ) - + def __post_init__(self): self.validate() - + def validate(self): - required_fields = ["model", "strategy", "dataset", "optimizer_fn", "scheduler_fn"] - + required_fields = [ + "model", + "strategy", + "dataset", + "optimizer_fn", + "scheduler_fn", + ] + for field_name in required_fields: if getattr(self, field_name) is None: raise ValueError(f"{field_name} is required.") - - \ No newline at end of file diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py index a10af3e..c924023 100644 --- a/khaosz/data/__init__.py +++ b/khaosz/data/__init__.py @@ -1,12 +1,12 @@ from khaosz.data.dataset import ( - BaseDataset, - SEQDataset, - DPODataset, - SFTDataset, + BaseDataset, + SEQDataset, + DPODataset, + SFTDataset, GRPODataset, MultiSegmentFetcher, DatasetLoader, - DatasetFactory + DatasetFactory, ) from khaosz.data.tokenizer import BpeTokenizer @@ -15,21 +15,17 @@ from khaosz.data.sampler import ResumableDistributedSampler __all__ = [ # Base classes "BaseDataset", - # Dataset implementations "SEQDataset", "SFTDataset", "DPODataset", "GRPODataset", - # Fetchers "MultiSegmentFetcher", - # Factory (DatasetLoader is alias for backward compatibility) "DatasetLoader", "DatasetFactory", - # Tokenizer and sampler "BpeTokenizer", - "ResumableDistributedSampler" -] \ No newline at end of file + "ResumableDistributedSampler", +] diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py index fc64ae3..bca8d2c 100644 --- a/khaosz/data/dataset.py +++ b/khaosz/data/dataset.py @@ -12,40 +12,42 @@ from typing import Callable, List, Dict, Literal, Optional, Union class BaseSegmentFetcher: """Fetches data segments across multiple tensor segments. - + Maintains cumulative lengths for efficient range queries across multiple discontinuous segments. """ - + def __init__(self, segments: List[Tensor]): self.segments = segments self.cum_lengths = [] - + total = 0 for seg in segments: total += torch.numel(seg) self.cum_lengths.append(total) - + self.total_length = total - + def __len__(self) -> int: return self.total_length def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: """Fetch data in the range [begin_idx, end_idx). - + Args: begin_idx: Starting index (inclusive) end_idx: Ending index (exclusive) - + Returns: Concatenated tensor of data in the specified range """ - if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length): + if not ( + 0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length + ): raise ValueError("begin_idx or end_idx out of bounds") if begin_idx >= end_idx: return torch.tensor([], dtype=torch.long) - + # Find segment boundaries for the range seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) @@ -64,43 +66,44 @@ class BaseSegmentFetcher: class MultiSegmentFetcher: """Manages multiple segment fetchers for different data keys. - + Each key corresponds to a different type of data (e.g., "sequence", "mask"). """ - + def __init__(self, muti_segments: Dict): self.muti_keys = list(muti_segments.keys()) self.muti_fetchers = { - key: BaseSegmentFetcher(segments) - for key, segments in muti_segments.items() + key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items() } - + def __len__(self) -> int: """Returns the minimum length across all fetchers.""" len_list = [len(seg) for seg in self.muti_fetchers.values()] return min(len_list) - - def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict: + + def key_fetch( + self, begin_idx: int, end_idx: int, keys: Union[str, List[str]] + ) -> Dict: """Fetch data for specific keys. - + Args: begin_idx: Starting index end_idx: Ending index keys: Single key or list of keys to fetch - + Returns: Dictionary of tensors if multiple keys, single tensor if one key """ - fetch_dict = {} + fetch_dict = {} keys = [keys] if isinstance(keys, str) else keys - + for key in keys: fetcher = self.muti_fetchers[key] fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) fetch_dict[key] = fetch_tensor return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] - + def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: """Fetch all keys.""" return self.key_fetch(begin_idx, end_idx, self.muti_keys) @@ -108,10 +111,10 @@ class MultiSegmentFetcher: class BaseDataset(Dataset, ABC): """Abstract base class for all dataset types. - + Implements common functionality for window-based data fetching. """ - + def __init__(self, window_size: int, stride: int): super().__init__() self.segments = {} @@ -122,38 +125,38 @@ class BaseDataset(Dataset, ABC): def load(self, load_path: str): """Load dataset from HDF5 file. - + Args: load_path: Path to the HDF5 data file """ self.segments = load_h5(load_path) self.fetcher = MultiSegmentFetcher(self.segments) self.total_samples = len(self.fetcher) - + def get_index(self, index: int) -> tuple: """Calculate begin and end indices for a sample. - + Args: index: Sample index - + Returns: Tuple of (begin_idx, end_idx) """ assert self.total_samples > self.window_size - + begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size) end_idx = min(begin_idx + self.window_size, self.total_samples - 1) - + return begin_idx, end_idx - + @abstractmethod def __getitem__(self, index: int) -> Dict[str, Tensor]: """Get a single sample by index. - + Must be implemented by subclasses. """ raise NotImplementedError - + def __len__(self) -> int: assert self.total_samples is not None if self.total_samples <= self.window_size: @@ -163,48 +166,50 @@ class BaseDataset(Dataset, ABC): class DatasetFactory: """Factory class for creating dataset instances. - + Supports decorator-based registration for extensible dataset types. All default dataset types (seq, sft, dpo, grpo) are registered automatically when their classes are defined with the decorator. - + Example usage: @DatasetFactory.register("custom") class CustomDataset(BaseDataset): ... - + dataset = DatasetFactory.create("custom", window_size, stride) """ - + SUPPORTED_TYPES = frozenset({"seq", "sft", "dpo", "grpo"}) DATASET_MAP: Dict[str, type] = {} - + @classmethod def register(cls, name: str): """Decorator to register a new dataset class. - + Args: name: Registration name for the dataset type - + Returns: Decorator function that registers the dataset class """ + def decorator(dataset_cls: type) -> type: if not issubclass(dataset_cls, BaseDataset): raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") cls.DATASET_MAP[name] = dataset_cls return dataset_cls + return decorator - + @classmethod def create(cls, train_type: str, window_size: int, stride: int) -> BaseDataset: """Create a dataset instance. - + Args: train_type: Type of training ("seq", "sft", "dpo", "grpo") window_size: Window size for data sampling stride: Stride between consecutive samples - + Returns: Dataset instance """ @@ -213,36 +218,42 @@ class DatasetFactory: f"Unknown dataset type: '{train_type}'. " f"Supported types: {sorted(cls.SUPPORTED_TYPES)}" ) - + if train_type not in cls.DATASET_MAP: raise NotImplementedError( f"Dataset type '{train_type}' is supported but not yet implemented." ) - + dataset_cls = cls.DATASET_MAP[train_type] return dataset_cls(window_size, stride) - + @classmethod - def load(cls, train_type: str, load_path: str, window_size: int, stride: Optional[int] = None) -> BaseDataset: + def load( + cls, + train_type: str, + load_path: str, + window_size: int, + stride: Optional[int] = None, + ) -> BaseDataset: """Create and load a dataset in one step. - + Args: train_type: Type of training dataset load_path: Path to the data file window_size: Window size for data sampling stride: Stride between consecutive samples (default: same as window_size) - + Returns: Loaded dataset instance """ if stride is None: stride = window_size - + dataset = cls.create(train_type, window_size, stride) dataset.load(load_path) - + return dataset - + @classmethod def available_types(cls) -> list: """Return list of registered dataset type names.""" @@ -256,46 +267,50 @@ class DatasetFactory: @DatasetFactory.register("seq") class SEQDataset(BaseDataset): """Dataset for sequential next-token prediction training.""" - + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") - + def __getitem__(self, index): begin_idx, end_idx = self.get_index(index) - + x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) - + return {"input_ids": x, "target_ids": y} @DatasetFactory.register("sft") class SFTDataset(BaseDataset): """Dataset for supervised fine-tuning with loss masking.""" - + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, key) - + def __getitem__(self, index): begin_idx, end_idx = self.get_index(index) - + x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) - y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) - loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool) - + y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to( + dtype=torch.long + ) + loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to( + dtype=torch.bool + ) + return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask} @DatasetFactory.register("dpo") class DPODataset(BaseDataset): """Dataset for Direct Preference Optimization training.""" - + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) @@ -304,25 +319,34 @@ class DPODataset(BaseDataset): def __getitem__(self, index: int): begin_idx, end_idx = self.get_index(index) - + chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long) rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long) - chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool) - rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool) + chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to( + dtype=torch.bool + ) + rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to( + dtype=torch.bool + ) - return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask} + return { + "chosen": chosen, + "rejected": rejected, + "chosen_mask": chosen_mask, + "rejected_mask": rejected_mask, + } @DatasetFactory.register("grpo") class GRPODataset(BaseDataset): """Dataset for Group Relative Policy Optimization training.""" - + def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, key) - + def __getitem__(self, index: int) -> Dict[str, Tensor]: begin_idx, end_idx = self.get_index(index) @@ -330,8 +354,13 @@ class GRPODataset(BaseDataset): responses = self._fetch_data(begin_idx, end_idx, "responses") masks = self._fetch_data(begin_idx, end_idx, "masks") rewards = self._fetch_data(begin_idx, end_idx, "rewards") - - return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards} + + return { + "prompts": prompts, + "responses": responses, + "masks": masks, + "rewards": rewards, + } # Backward compatibility alias diff --git a/khaosz/data/sampler.py b/khaosz/data/sampler.py index 114afd3..82162c5 100644 --- a/khaosz/data/sampler.py +++ b/khaosz/data/sampler.py @@ -7,45 +7,45 @@ from typing import Optional class ResumableDistributedSampler(Sampler[int]): def __init__( - self, + self, data_source: Dataset, - start_epoch: int=0, - start_iter: int=0, - seed: int=42, - drop_last: bool=False, - shuffle: bool=True, - process_group: Optional[dist.ProcessGroup]=None, + start_epoch: int = 0, + start_iter: int = 0, + seed: int = 42, + drop_last: bool = False, + shuffle: bool = True, + process_group: Optional[dist.ProcessGroup] = None, ): self.epoch = start_epoch self.iter = start_iter self.seed = seed self.num_samples = len(data_source) - + if process_group is not None: # input process group self.rank = dist.get_rank(process_group) self.num_replicas = dist.get_world_size(process_group) - + elif dist.is_available() and dist.is_initialized(): # use default process group process_group = dist.group.WORLD self.rank = dist.get_rank() self.num_replicas = dist.get_world_size() - + else: # single process self.rank = 0 self.num_replicas = 1 - + self.drop_last = drop_last self.shuffle = shuffle - - offset = 0 if drop_last else self.num_replicas - 1 + + offset = 0 if drop_last else self.num_replicas - 1 self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas - + self._indices = None - + def _get_indices(self): if self.shuffle: generator = torch.Generator() @@ -53,26 +53,26 @@ class ResumableDistributedSampler(Sampler[int]): indices = torch.randperm(self.num_samples, generator=generator).tolist() else: indices = torch.arange(self.num_samples).tolist() - + if not self.drop_last and self.num_samples < self.total_size: padding_size = self.total_size - len(indices) indices += indices[:padding_size] - - local_indices = indices[self.rank:self.total_size:self.num_replicas] - + + local_indices = indices[self.rank : self.total_size : self.num_replicas] + self.iter = self.iter % self.num_samples_per_replica - self._indices = local_indices[self.iter:] - + self._indices = local_indices[self.iter :] + def __iter__(self): if self._indices is None: self._get_indices() - + for i in self._indices: self.iter += 1 yield i - + self.epoch += 1 self._indices = None - + def __len__(self): - return self.num_samples_per_replica \ No newline at end of file + return self.num_samples_per_replica diff --git a/khaosz/data/serialization.py b/khaosz/data/serialization.py index 3258172..c234bfe 100644 --- a/khaosz/data/serialization.py +++ b/khaosz/data/serialization.py @@ -10,24 +10,26 @@ from torch import Tensor from typing import Any, Dict, List from khaosz.parallel.setup import get_rank + def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): os.makedirs(file_path, exist_ok=True) full_file_path = os.path.join(file_path, f"{file_name}.h5") - with h5py.File(full_file_path, 'w') as f: + with h5py.File(full_file_path, "w") as f: for key, tensors in tensor_group.items(): grp = f.create_group(key) for idx, tensor in enumerate(tensors): arr = tensor.cpu().numpy() - grp.create_dataset(f'data_{idx}', data=arr) + grp.create_dataset(f"data_{idx}", data=arr) + def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: tensor_group: Dict[str, List[Tensor]] = {} root_path = Path(file_path) h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) - + for h5_file in h5_files: - with h5py.File(h5_file, 'r') as f: + with h5py.File(h5_file, "r") as f: for key in f.keys(): grp = f[key] dsets = [] @@ -37,7 +39,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: if share_memory: tensor = tensor.share_memory_() dsets.append(tensor) - + if tensor_group.get(key) is None: tensor_group[key] = [] tensor_group[key].extend(dsets) @@ -60,7 +62,7 @@ class Checkpoint: self, save_dir: str, ) -> None: - + save_path = Path(save_dir) save_path.mkdir(parents=True, exist_ok=True) @@ -72,7 +74,7 @@ class Checkpoint: } with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2) - + st.save_file(self.state_dict, save_path / f"state_dict.safetensors") @classmethod @@ -83,7 +85,7 @@ class Checkpoint: rank = get_rank() save_path = Path(save_dir) - + meta = {} if rank == 0: with open(Path(save_dir) / "meta.json", "r") as f: @@ -100,4 +102,4 @@ class Checkpoint: state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"], - ) \ No newline at end of file + ) diff --git a/khaosz/data/tokenizer.py b/khaosz/data/tokenizer.py index 689b5c2..ff850e1 100644 --- a/khaosz/data/tokenizer.py +++ b/khaosz/data/tokenizer.py @@ -9,34 +9,46 @@ class BpeTokenizer: def __init__(self, path=None): self._control_tokens = ["", "", ""] self._special_tokens = ["<|im_start|>", "<|im_end|>"] - + model = BPE() self._tokenizer = Tokenizer(model) - self._tokenizer.normalizer = normalizers.Sequence([ - normalizers.NFC(), - normalizers.Strip() - ]) - - self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ - pre_tokenizers.UnicodeScripts(), - pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True) - ]) - + self._tokenizer.normalizer = normalizers.Sequence( + [normalizers.NFC(), normalizers.Strip()] + ) + + self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.UnicodeScripts(), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True), + ] + ) + self._tokenizer.decoder = decoders.ByteLevel() self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) - + if path is not None: self._tokenizer = Tokenizer.from_file(path) - - def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int, max_token_length=18) -> tuple: + + def _prepare_trainer( + self, + vocab_size: int, + min_freq: int, + reserved_token_size: int, + max_token_length=18, + ) -> tuple: assert reserved_token_size > len(self._special_tokens) - reserved_tokens = [f"<|reserve{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))] - detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens)) - + reserved_tokens = [ + f"<|reserve{i:02d}|>" + for i in range(reserved_token_size - len(self._special_tokens)) + ] + detail_vocab_size = vocab_size - ( + len(reserved_tokens) + len(self._special_tokens) + ) + alphabet = pre_tokenizers.ByteLevel.alphabet() min_size = len(alphabet) + len(self._control_tokens) assert detail_vocab_size > min_size - + trainer = BpeTrainer( vocab_size=detail_vocab_size, min_frequency=min_freq, @@ -46,61 +58,74 @@ class BpeTokenizer: initial_alphabet=alphabet, show_progress=True, ) - + return trainer, detail_vocab_size, reserved_tokens def train(self, files, vocab_size, min_freq, reserved_token_size=100): trainer, _, reserved_tokens = self._prepare_trainer( vocab_size=vocab_size, min_freq=min_freq, - reserved_token_size=reserved_token_size + reserved_token_size=reserved_token_size, ) self._tokenizer.train(files=files, trainer=trainer) self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens) - - def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100): + + def train_from_iterator( + self, iterator, vocab_size, min_freq, reserved_token_size=100 + ): trainer, _, reserved_tokens = self._prepare_trainer( vocab_size=vocab_size, min_freq=min_freq, - reserved_token_size=reserved_token_size + reserved_token_size=reserved_token_size, ) self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer) self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens) - + def save(self, path): self._tokenizer.save(path) - + def load(self, path): self._tokenizer = Tokenizer.from_file(path) - def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List: + def encode( + self, + tokens: Union[str, List[str]], + out_ids: bool = True, + add_special_tokens: bool = False, + ) -> List: if isinstance(tokens, str): - encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens) + encoded: Encoding = self._tokenizer.encode( + tokens, add_special_tokens=add_special_tokens + ) return encoded.ids if out_ids else encoded.tokens elif isinstance(tokens, list): - encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens) - return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list] + encoded_list: List[Encoding] = self._tokenizer.encode_batch( + tokens, add_special_tokens=add_special_tokens + ) + return [ + encoded.ids if out_ids else encoded.tokens for encoded in encoded_list + ] - def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str: + def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str: return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) - + def __len__(self) -> int: return self._tokenizer.get_vocab_size() - + @property def stop_ids(self) -> List[int]: stop_token = self._control_tokens + self._special_tokens stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token] return stop_ids - + @property def bos_id(self) -> int: return self._tokenizer.token_to_id("") - + @property def eos_id(self) -> int: return self._tokenizer.token_to_id("") - + @property def pad_id(self) -> int: - return self._tokenizer.token_to_id("") \ No newline at end of file + return self._tokenizer.token_to_id("") diff --git a/khaosz/inference/__init__.py b/khaosz/inference/__init__.py index c652dc4..03549b5 100644 --- a/khaosz/inference/__init__.py +++ b/khaosz/inference/__init__.py @@ -11,7 +11,7 @@ from khaosz.inference.generator import ( StreamGenerator, BatchGenerator, EmbeddingEncoder, - GeneratorFactory + GeneratorFactory, ) __all__ = [ @@ -19,11 +19,10 @@ __all__ = [ "GeneratorCore", "EmbeddingEncoderCore", "KVCacheManager", - "GenerationRequest", "LoopGenerator", "StreamGenerator", "BatchGenerator", "EmbeddingEncoder", - "GeneratorFactory" -] \ No newline at end of file + "GeneratorFactory", +] diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py index dc37229..77129ce 100644 --- a/khaosz/inference/core.py +++ b/khaosz/inference/core.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from torch import Tensor +from torch import Tensor from contextlib import contextmanager from typing import Any, Callable, List, Tuple, Union, Optional, Self from khaosz.config import ModelParameter, ModelConfig @@ -12,58 +12,61 @@ def apply_sampling_strategies( temperature: float, top_k: int, top_p: float, - filter_value: float = -float("inf") + filter_value: float = -float("inf"), ) -> Tensor: - """ + """ Apply sampling strategies to the logits tensor. - + Args: logits (Tensor): The logits tensor. temperature (float): The temperature parameter. top_k (int): The top-k parameter. top_p (float): The top-p parameter. filter_value (float, optional): The filter value. Defaults to -float("inf"). - + Returns: Tensor: The sampled logits tensor. - + """ - + if temperature != 1.0: logits = logits / temperature - + if top_k > 0: top_k = min(top_k, logits.size(-1)) indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] logits[indices_to_remove] = filter_value - + if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - + sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 - + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) indices_to_remove.scatter_( - dim=1, - index=sorted_indices, - src=sorted_indices_to_remove + dim=1, index=sorted_indices, src=sorted_indices_to_remove ) - + logits[indices_to_remove] = filter_value - + return logits @contextmanager def disable_random_init(): init_functions = [ - 'xavier_normal_', 'xavier_uniform_', - 'kaiming_normal_', 'kaiming_uniform_', - 'zeros_', 'ones_', 'constant_', - 'normal_', 'uniform_' + "xavier_normal_", + "xavier_uniform_", + "kaiming_normal_", + "kaiming_uniform_", + "zeros_", + "ones_", + "constant_", + "normal_", + "uniform_", ] original_funcs = {} for name in init_functions: @@ -82,7 +85,7 @@ class GeneratorCore: self.model = parameter.model self.tokenizer = parameter.tokenizer self.config = parameter.config - + def generate_iterator( self, input_ids: Tensor, @@ -91,18 +94,18 @@ class GeneratorCore: top_p: float, attn_mask: Optional[Tensor] = None, kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, - start_pos: int = 0 - )-> Tuple[Tensor, int]: - + start_pos: int = 0, + ) -> Tuple[Tensor, int]: + with torch.inference_mode(): outputs = self.model(input_ids, attn_mask, kv_caches, start_pos) logits = outputs["logits"][:, -1, :] - cache_increase = input_ids.size(-1) - + cache_increase = input_ids.size(-1) + logits = apply_sampling_strategies(logits, temperature, top_k, top_p) probs = torch.softmax(logits, dim=-1) next_token_id = torch.multinomial(probs, num_samples=1) - + return next_token_id, cache_increase def generate_loop( @@ -115,14 +118,21 @@ class GeneratorCore: attn_mask: Optional[Tensor] = None, kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None, start_pos: int = 0, - callback: Optional[Callable[..., Any]] = None + callback: Optional[Callable[..., Any]] = None, ) -> List[int]: cur_cache_pos = start_pos - + for _ in range(len(ids), self.config.max_len): next_token_id, cache_increase = self.generate_iterator( - input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos) - + input_ids, + temperature, + top_k, + top_p, + attn_mask, + kv_caches, + cur_cache_pos, + ) + input_ids = next_token_id ids.append(next_token_id.item()) cur_cache_pos += cache_increase @@ -132,9 +142,9 @@ class GeneratorCore: if next_token_id.item() in self.tokenizer.stop_ids: break - + return ids - + def to(self, *args, **kargs) -> Self: self.model.to(*args, **kargs) return self @@ -145,32 +155,35 @@ class EmbeddingEncoderCore: self.model = parameter.model self.tokenizer = parameter.tokenizer self.config = parameter.config - + def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: with_batch = isinstance(sentence, list) ids = self.tokenizer.encode(sentence) batch_ids = ids if with_batch else [ids] max_model_len = self.config.max_len - + all_fragments = [] fragment_origin_idx = [] - + for i, seq in enumerate(batch_ids): if len(seq) > max_model_len: - fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)] + fragments = [ + seq[j : j + max_model_len] + for j in range(0, len(seq), max_model_len) + ] all_fragments.extend(fragments) fragment_origin_idx.extend([i] * len(fragments)) else: all_fragments.append(seq) fragment_origin_idx.append(i) - - #if empty fragments + + # if empty fragments if not all_fragments or not ids: return [] if with_batch else torch.tensor([]) - + device = next(self.model.parameters()).device max_len = min(max(len(seq) for seq in all_fragments), max_model_len) - + padded_ids = [] masks = [] for seq in all_fragments: @@ -179,24 +192,30 @@ class EmbeddingEncoderCore: mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq] padded_ids.append(padded_seq) masks.append(mask) - + input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long) seq_mask = torch.tensor(masks, device=device, dtype=torch.bool) - + with torch.inference_mode(): outputs = self.model(input_tensor, seq_mask)["hidden_states"] # [num_fragments, seq_len, hidden_size] - fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1)) - + fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1)) + sentence_embs: List[Tensor] = [] for i in range(len(batch_ids)): - indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i] + indices = [ + idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i + ] if indices: - sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size] - length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1] - emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size] + sum_frags = torch.sum( + fragment_embs[indices, :, :], dim=1 + ) # [frags, hidden_size] + length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze( + 1 + ) # [frags, 1] + emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size] sentence_embs.append(emb.flatten()) - + if with_batch: return [emb.flatten() for emb in sentence_embs] else: @@ -209,11 +228,11 @@ class EmbeddingEncoderCore: class KVCacheManager: def __init__( - self, - config: ModelConfig, + self, + config: ModelConfig, batch_size: int, - device: torch.device = "cuda", - dtype: torch.dtype = torch.bfloat16 + device: torch.device = "cuda", + dtype: torch.dtype = torch.bfloat16, ): self.batch_size = batch_size self.device = device @@ -221,25 +240,41 @@ class KVCacheManager: self.num_layers = config.n_layers self.max_len = config.max_len self.num_heads = config.n_kv_heads - self.head_dim = config.dim //config.n_heads - + self.head_dim = config.dim // config.n_heads + self._kv_cache: Tuple[Tensor, Tensor] = None self._seq_mask: Tensor = None self._initialize() - + def _initialize(self): k_cache = torch.empty( - (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), - device=self.device, dtype=self.dtype + ( + self.batch_size, + self.max_len, + self.num_layers, + self.num_heads, + self.head_dim, + ), + device=self.device, + dtype=self.dtype, ) v_cache = torch.empty( - (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), - device=self.device, dtype=self.dtype + ( + self.batch_size, + self.max_len, + self.num_layers, + self.num_heads, + self.head_dim, + ), + device=self.device, + dtype=self.dtype, ) self._kv_cache = (k_cache, v_cache) - self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool) + self._seq_mask = torch.ones( + (self.batch_size, self.max_len), device=self.device, dtype=torch.bool + ) - def update(self, active_mask: Tensor): + def update(self, active_mask: Tensor): k_cache, v_cache = self._kv_cache self._kv_cache = (k_cache[active_mask], v_cache[active_mask]) self._seq_mask = self._seq_mask[active_mask] @@ -250,14 +285,14 @@ class KVCacheManager: self._seq_mask = None else: self._initialize() - + def set_seq_mask(self, input_ids: Tensor, pad_id: int): batch_size, seq_len = input_ids.shape - bool_mask = (input_ids != pad_id) - self._seq_mask[: batch_size, : seq_len] = bool_mask + bool_mask = input_ids != pad_id + self._seq_mask[:batch_size, :seq_len] = bool_mask def get_kvcache(self) -> Tuple[Tensor, Tensor]: return self._kv_cache - - def get_seq_mask(self) -> Tensor: - return self._seq_mask \ No newline at end of file + + def get_seq_mask(self) -> Tensor: + return self._seq_mask diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py index e047899..78a4357 100644 --- a/khaosz/inference/generator.py +++ b/khaosz/inference/generator.py @@ -1,6 +1,6 @@ import torch from dataclasses import dataclass -from torch import Tensor +from torch import Tensor from typing import List, Tuple, Union, Optional, Generator from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager from khaosz.config.param_config import ModelParameter @@ -8,10 +8,11 @@ from khaosz.config.param_config import ModelParameter HistoryType = List[Tuple[str, str]] + def build_prompt( query: str, system_prompt: Optional[str] = None, - history: Optional[HistoryType] = None + history: Optional[HistoryType] = None, ) -> str: """ Build prompt in ChatML format for query and history. @@ -42,17 +43,17 @@ def build_prompt( def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]: - """ + """ Pad a list of sequences to a fixed length. - + Args: ids_list (List[List[int]]): A list of sequences. max_ids_len (int): The maximum length of sequences. pad_id (int): The id to pad sequences. - + Returns: List[List[int]]: A list of padded sequences. - + """ max_ids_len = max(len(ids) for ids in ids_list) new_ids_list = [] @@ -60,7 +61,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int] pad_len = max_ids_len - len(ids) padded_seq = [pad_id] * pad_len + ids new_ids_list.append(padded_seq) - + return new_ids_list, max_ids_len @@ -68,7 +69,7 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int] class GenerationRequest: """ Request parameters for text generation. - + Attributes: top_k: Top-k sampling parameter. top_p: Top-p (nucleus) sampling parameter. @@ -79,6 +80,7 @@ class GenerationRequest: system_prompt: System prompt for the conversation. stream: Whether to use streaming generation. """ + top_k: int top_p: float temperature: float @@ -101,63 +103,66 @@ class GenerationRequest: class LoopGenerator(GeneratorCore): def __init__(self, parameter: ModelParameter): super().__init__(parameter) - + def generate(self, request: GenerationRequest) -> str: device = next(self.model.parameters()).device cache_manager = KVCacheManager(self.config, 1, device=device) - + prompt = build_prompt(request.query, request.history) ids = self.tokenizer.encode(prompt) input_ids = torch.tensor([ids], device=device, dtype=torch.long) - + start_cache_pos = len(ids) self.model.eval() kv_caches = cache_manager.get_kvcache() - + ids = self.generate_loop( - input_ids, - ids, - request.temperature, - request.top_k, - request.top_p, - kv_caches=kv_caches, + input_ids, + ids, + request.temperature, + request.top_k, + request.top_p, + kv_caches=kv_caches, ) response = self.tokenizer.decode(ids[start_cache_pos:]) - + return response class StreamGenerator(GeneratorCore): def __init__(self, parameter: ModelParameter): super().__init__(parameter) - - def generate(self, request: GenerationRequest) -> Generator[str, None, None]: + + def generate(self, request: GenerationRequest) -> Generator[str, None, None]: device = next(self.model.parameters()).device cache_manager = KVCacheManager(self.config, 1, device=device) prompt = build_prompt(request.query, request.history) ids = self.tokenizer.encode(prompt) input_ids = torch.tensor([ids], device=device, dtype=torch.long) - + start_cache_pos = len(ids) cur_cache_pos = 0 self.model.eval() kv_caches = cache_manager.get_kvcache() - + for _ in range(len(ids), self.config.max_len): next_token_id, cache_increase = self.generate_iterator( - input_ids, request.temperature, request.top_k, request.top_p, - kv_caches=kv_caches, - start_pos=cur_cache_pos + input_ids, + request.temperature, + request.top_k, + request.top_p, + kv_caches=kv_caches, + start_pos=cur_cache_pos, ) - + input_ids = next_token_id ids.append(next_token_id.item()) cur_cache_pos += cache_increase - + response = self.tokenizer.decode(ids[start_cache_pos:]) yield response - + if next_token_id.item() in self.tokenizer.stop_ids: yield response + "\n" break @@ -166,131 +171,140 @@ class StreamGenerator(GeneratorCore): class BatchGenerator(GeneratorCore): def __init__(self, parameter: ModelParameter): super().__init__(parameter) - + def generate(self, request: GenerationRequest) -> List[str]: batch_size = len(request.query) if request.history is None: request.history = [[] for _ in range(batch_size)] - - prompts = [build_prompt(query, history) for query, history in zip(request.query, request.history)] - + + prompts = [ + build_prompt(query, history) + for query, history in zip(request.query, request.history) + ] + ids_list = [self.tokenizer.encode(prompt) for prompt in prompts] ids_list, max_ids_len = pad_sequence(ids_list, self.tokenizer.pad_id) - + device = next(self.model.parameters()).device cache_manager = KVCacheManager(self.config, batch_size, device=device) - + input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long) cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id) activate_task_mask = [True] * batch_size - + start_cache_pos = max_ids_len cur_cache_pos = 0 - + while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0: kv_caches = cache_manager.get_kvcache() - attn_mask =cache_manager.get_seq_mask() - + attn_mask = cache_manager.get_seq_mask() + next_token_id, cache_increase = self.generate_iterator( - input_tensor, request.temperature, request.top_k, request.top_p, - attn_mask=attn_mask, - kv_caches=kv_caches, - start_pos=cur_cache_pos + input_tensor, + request.temperature, + request.top_k, + request.top_p, + attn_mask=attn_mask, + kv_caches=kv_caches, + start_pos=cur_cache_pos, ) - + cur_cache_pos += cache_increase active_mask = [] c_ids = 0 - + for i in range(batch_size): if activate_task_mask[i]: token = next_token_id[c_ids, :].item() ids_list[i].append(token) c_ids += 1 - + is_active = not token in self.tokenizer.stop_ids - activate_task_mask[i] = is_active + activate_task_mask[i] = is_active active_mask.append(is_active) - + active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool) cache_manager.update(active_mask) input_tensor = next_token_id[active_mask, :] max_ids_len += 1 - + responses = [str()] * batch_size for i in range(batch_size): responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:]) request.history[i].append((request.query[i], responses[i])) - + return responses class EmbeddingEncoder(EmbeddingEncoderCore): def __init__(self, parameter: ModelParameter): super().__init__(parameter) - + def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: return super().encode(sentence) class GeneratorFactory: """Factory class for creating generator instances. - + Provides smart generator selection based on request characteristics: - Streaming: Use StreamGenerator for streaming output - Batch: Use BatchGenerator when query is a list - Single: Use LoopGenerator for single query non-streaming - + Example usage: generator = GeneratorFactory.create_generator(parameter, request) result = generator.generate(request) """ - + @staticmethod - def create_generator(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: + def create_generator( + parameter: ModelParameter, request: GenerationRequest + ) -> GeneratorCore: """Create a generator based on request characteristics. - + Args: parameter: Model parameters containing model, tokenizer, config request: Generation request with query, options, etc. - + Returns: Appropriate GeneratorCore subclass instance """ # Streaming generation: check stream field first if request.stream: return StreamGenerator(parameter) - + # Batch generation: query is a list of strings if isinstance(request.query, list): return BatchGenerator(parameter) - + # Default: single query non-streaming return LoopGenerator(parameter) - + @staticmethod def create_encoder(parameter: ModelParameter) -> EmbeddingEncoderCore: """Create an embedding encoder instance. - + Args: parameter: Model parameters - + Returns: EmbeddingEncoderCore instance """ return EmbeddingEncoder(parameter) - + @classmethod - def create(cls, parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore: + def create( + cls, parameter: ModelParameter, request: GenerationRequest + ) -> GeneratorCore: """Convenience method that delegates to create_generator. - + Args: parameter: Model parameters request: Generation request - + Returns: Generator instance """ return cls.create_generator(parameter, request) - \ No newline at end of file diff --git a/khaosz/model/__init__.py b/khaosz/model/__init__.py index c70b05c..4f71990 100644 --- a/khaosz/model/__init__.py +++ b/khaosz/model/__init__.py @@ -1,17 +1,10 @@ -from khaosz.model.module import ( +from khaosz.model.module import ( Linear, - RMSNorm, + RMSNorm, MLP, GQA, DecoderBlock, ) from khaosz.model.transformer import Transformer -__all__ = [ - "Linear", - "RMSNorm", - "MLP", - "GQA", - "DecoderBlock", - "Transformer" -] \ No newline at end of file +__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"] diff --git a/khaosz/model/module.py b/khaosz/model/module.py index ca14b40..e29eb5f 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -7,7 +7,7 @@ from typing import Optional, Tuple def repeat_kv(x: Tensor, n_rep: int) -> Tensor: - """ + """ Repeat k times along the dimension for attention heads. Args: x (Tensor): The input tensor. @@ -15,7 +15,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor: Returns: Tensor: The repeated tensor. """ - + bs, slen, n_heads, head_dim = x.shape if n_rep == 1: return x @@ -25,12 +25,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor: .reshape(bs, slen, n_heads * n_rep, head_dim) ) + def get_rotary_emb( - dim: int, - max_len: int, - base: float = 10000, - ) -> Tuple[Tensor, Tensor]: - """ + dim: int, + max_len: int, + base: float = 10000, +) -> Tuple[Tensor, Tensor]: + """ Get the rotary embedding for the given dimension and maximum length. Args: dim (int): The dimension of the input. @@ -46,6 +47,7 @@ def get_rotary_emb( return torch.cos(freqs).float(), torch.sin(freqs).float() + def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: """ Apply rotary embedding to the input tensor using cos/sin form. @@ -55,49 +57,49 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens Returns: Tensor: The output tensor (rotated, same shape as input). """ - + dtype = x.dtype cos, sin = rotary_emb - + cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] - + x_real = x[..., 0::2] # [batch, seq_len, dim//2] x_imag = x[..., 1::2] # [batch, seq_len, dim//2] - + x_real_rot = x_real * cos - x_imag * sin x_imag_rot = x_real * sin + x_imag * cos - + x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2] - x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim] - + x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim] + return x_out.to(dtype) class RotaryEmbedding(nn.Module): - def __init__(self, dim: int, max_len: int, base: int=10000): + def __init__(self, dim: int, max_len: int, base: int = 10000): super().__init__() self.dim = dim self.max_len = max_len self.base = base self.max_len_cached = None self._set_rotary_buffer(self.max_len) - + def _set_rotary_buffer(self, max_len: int): cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base) self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) self.max_len_cached = max_len - - def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]: + + def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]: seq_len = x.size(1) - + if self.max_len_cached < seq_len + start_pos: self._set_rotary_buffer(seq_len + start_pos) - + cos = self.cos_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len] - + return (cos, sin) @@ -115,43 +117,42 @@ class RMSNorm(nn.Module): def __init__(self, dim, norm_eps): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) - self.normalized_shape = (dim, ) + self.normalized_shape = (dim,) self.norm_eps = norm_eps def forward(self, x: Tensor) -> Tensor: - rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) + rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) return rms.to(x.dtype) - - + + class MLP(nn.Module): def __init__(self, dim: int, dim_feed_forward: int): super().__init__() self.up = Linear(dim, dim_feed_forward) self.gate = Linear(dim, dim_feed_forward) self.down = Linear(dim_feed_forward, dim) - + def forward(self, x: Tensor) -> Tensor: gated = self.up(x) * F.silu(self.gate(x)) out = self.down(gated) return out - class GQA(nn.Module): def __init__( - self, - dim: int, - n_heads: int, - n_kv_heads: int, + self, + dim: int, + n_heads: int, + n_kv_heads: int, use_qk_norm: bool, norm_eps: float, use_gated_attention: bool, - layer_id: int + layer_id: int, ): super().__init__() assert dim % n_heads == 0 assert n_heads % n_kv_heads == 0 - + self.head_dim = dim // n_heads self.layer_id = layer_id self.dim = dim @@ -165,11 +166,11 @@ class GQA(nn.Module): self.k_proj = Linear(dim, n_kv_heads * self.head_dim) self.v_proj = Linear(dim, n_kv_heads * self.head_dim) self.o_proj = Linear(dim, dim) - + if self.use_qk_norm: self.q_norm = RMSNorm(self.head_dim, norm_eps) self.k_norm = RMSNorm(self.head_dim, norm_eps) - + if self.use_gated_attention: self.gate = Linear(dim, dim) @@ -177,14 +178,14 @@ class GQA(nn.Module): batch_size, seq_len, _ = x.shape x = x.reshape(batch_size, seq_len, n_heads, self.head_dim) return x - + def forward( self, - x: Tensor, - rotary_emb: Tuple[Tensor, Tensor], + x: Tensor, + rotary_emb: Tuple[Tensor, Tensor], mask: Tensor = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, - start_pos: int = 0 + start_pos: int = 0, ) -> Tensor: bsz, seq_len, _ = x.size() is_causal = mask is None @@ -194,31 +195,36 @@ class GQA(nn.Module): k = self._split_heads(self.k_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads) q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb) - + if self.use_qk_norm: q, k = self.q_norm(q), self.k_norm(k) - + if kv_cache is not None: k_cache, v_cache = kv_cache - + # copy to cache - k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k - v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v - + k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k + v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v + # get cache - k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] - v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] - + k = k_cache[:bsz, : start_pos + seq_len, self.layer_id] + v = v_cache[:bsz, : start_pos + seq_len, self.layer_id] + k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) - + # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) # (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim) - sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2) - + sdqa_out = ( + F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) + .permute(0, 2, 1, 3) + .contiguous() + .flatten(2) + ) + if self.use_gated_attention: sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) - + out = self.o_proj(sdqa_out) return out @@ -227,15 +233,15 @@ class GQA(nn.Module): class MLA(nn.Module): def __init__( self, - dim: int, - n_heads: int, + dim: int, + n_heads: int, n_kv_heads: int, kv_lora_rank: int, qk_nope_head_dim: int, qk_rope_head_dim: int, norm_eps: float, use_gated_attention: bool, - layer_id: int + layer_id: int, ): super().__init__() self.dim = dim @@ -252,45 +258,46 @@ class MLA(nn.Module): self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps) - + # KV (k_nope, k_rope, v) self.kv_b_proj = Linear( - kv_lora_rank, + kv_lora_rank, n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), ) - + self.o_proj = Linear(dim, dim, bias=False) - + if use_gated_attention: self.gate = Linear(dim, dim, bias=False) - + def forward( self, x: Tensor, rotary_emb: Tuple[Tensor, Tensor], mask: Tensor = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, - start_pos: int = 0 + start_pos: int = 0, ) -> Tensor: bsz, seq_len, _ = x.size() is_causal = mask is None - + q = self.q_proj(x) q = q.view(bsz, seq_len, self.n_heads, self.head_dim) - + kv_compressed = self.kv_a_proj(x) kv_compressed = self.kv_norm(kv_compressed) - + kv = self.kv_b_proj(kv_compressed) kv = kv.view(bsz, seq_len, self.n_kv_heads, -1) - + k_nope, k_rope, v = torch.split( - kv, - [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], - dim=-1 + kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1 + ) + + q_nope, q_rope = ( + q[..., : self.qk_nope_head_dim], + q[..., self.qk_rope_head_dim :], ) - - q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_rope_head_dim:] q_rope = apply_rotary_emb(q_rope, rotary_emb) k_rope = apply_rotary_emb(k_rope, rotary_emb) @@ -299,41 +306,48 @@ class MLA(nn.Module): if kv_cache is not None: k_cache, v_cache = kv_cache - k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k - v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v - k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] - v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] - + k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k + v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v + k = k_cache[:bsz, : start_pos + seq_len, self.layer_id] + v = v_cache[:bsz, : start_pos + seq_len, self.layer_id] + q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2) - + if self.use_gated_attention: attn_out = attn_out * F.sigmoid(self.gate(x)) out = self.o_proj(attn_out) - + return out class DecoderBlock(nn.Module): def __init__( - self, - dim: int, - n_heads: int, - dim_ffn: int, - n_kv_heads: int, - norm_eps: int, - use_qk_norm: bool, - use_gated_attention: bool, - layer_id: int + self, + dim: int, + n_heads: int, + dim_ffn: int, + n_kv_heads: int, + norm_eps: int, + use_qk_norm: bool, + use_gated_attention: bool, + layer_id: int, ): super().__init__() - self.attention = GQA(dim, n_heads, n_kv_heads, - use_qk_norm, norm_eps, use_gated_attention, layer_id) + self.attention = GQA( + dim, + n_heads, + n_kv_heads, + use_qk_norm, + norm_eps, + use_gated_attention, + layer_id, + ) self.input_norm = RMSNorm(dim, norm_eps) self.mlp = MLP(dim, dim_ffn) self.post_attention_norm = RMSNorm(dim, norm_eps) @@ -341,24 +355,20 @@ class DecoderBlock(nn.Module): def forward( self, x: Tensor, - rotary_emb: Tuple[Tensor, Tensor], + rotary_emb: Tuple[Tensor, Tensor], attention_mask: Optional[Tensor] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, - start_pos: int = 0 + start_pos: int = 0, ) -> Tensor: # attention attn_output = self.attention( - self.input_norm(x), - rotary_emb, - attention_mask, - kv_cache, - start_pos + self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos ) x = attn_output + x - + # feed forward x = self.mlp(self.post_attention_norm(x)) + x - + return x @@ -366,6 +376,6 @@ class Embedding(nn.Module): def __init__(self, vocab_size: int, embedding_dim: int): super().__init__() self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim))) - + def forward(self, x: Tensor) -> Tensor: - return F.embedding(x, self.weight) \ No newline at end of file + return F.embedding(x, self.weight) diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py index 840bf6b..a90fade 100644 --- a/khaosz/model/transformer.py +++ b/khaosz/model/transformer.py @@ -4,15 +4,21 @@ import torch.nn as nn from torch import Tensor from typing import Any, Mapping, Optional, Tuple from khaosz.config.model_config import ModelConfig -from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding +from khaosz.model.module import ( + Embedding, + DecoderBlock, + Linear, + RMSNorm, + RotaryEmbedding, +) def process_attention_mask( - seq_mask: Tensor, - input_tensor: Tensor, - start_pos: int = 0, - is_causal: bool = False, - ) -> Tensor: + seq_mask: Tensor, + input_tensor: Tensor, + start_pos: int = 0, + is_causal: bool = False, +) -> Tensor: """ Create attention mask for GQA Args: @@ -26,32 +32,36 @@ def process_attention_mask( device = input_tensor.device dtype = input_tensor.dtype seq_len = input_tensor.size(1) - + if seq_mask is None: if start_pos != 0: # for single prompt chat seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) else: return None - + if seq_mask.dim() > 2: # shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos) # if ndim > 2, it's 4D tensor return seq_mask - + batch_size = seq_mask.size(0) - seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool) + seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool) # (bsz, start_pos + seq_len) - expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len) + expanded_mask = seq_mask.unsqueeze(1).expand( + batch_size, seq_len, start_pos + seq_len + ) # (bsz, seq_len, start_pos + seq_len) - + if is_causal: expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) - + attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) - attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1) + attention_mask = attention_mask.masked_fill_( + ~expanded_mask, -torch.finfo(dtype).max / 2 + ).unsqueeze(1) # (bsz, 1, seq_len, seq_len + start_pos) - + return attention_mask @@ -59,26 +69,38 @@ class Transformer(nn.Module): def __init__(self, config: ModelConfig): super().__init__() self.config = config - self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len) + self.rotary_embeding = RotaryEmbedding( + config.dim // config.n_heads, config.max_len + ) self.embed_tokens = Embedding(config.vocab_size, config.dim) - - self.layers = nn.ModuleList([ - DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads, - config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id) - for layer_id in range(config.n_layers) - ]) - + + self.layers = nn.ModuleList( + [ + DecoderBlock( + config.dim, + config.n_heads, + config.dim_ffn, + config.n_kv_heads, + config.norm_eps, + config.use_qk_norm, + config.use_gated_attention, + layer_id, + ) + for layer_id in range(config.n_layers) + ] + ) + self.norm = RMSNorm(config.dim, config.norm_eps) self.lm_head = Linear(config.dim, config.vocab_size) - + if self.config.tie_weight == True: self.lm_head.weight = self.embed_tokens.weight self._init_parameters() - + def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False): - lm_head_key = 'lm_head.weight' - embed_key = 'embed_tokens.weight' + lm_head_key = "lm_head.weight" + embed_key = "embed_tokens.weight" if self.config.tie_weight == True: # same tensor @@ -87,48 +109,44 @@ class Transformer(nn.Module): if lm_head_key not in state_dict and embed_key in state_dict: # use clone to avoid sharing the same tensor state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) - + return super().load_state_dict(state_dict, strict, assign) - - def state_dict(self, destination=None, prefix='', keep_vars=False): - state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - + + def state_dict(self, destination=None, prefix="", keep_vars=False): + state_dict = super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + if self.config.tie_weight == True: - lm_head_key = prefix + 'lm_head.weight' + lm_head_key = prefix + "lm_head.weight" if lm_head_key in state_dict: del state_dict[lm_head_key] - + return state_dict - + def _init_parameters(self): for param in self.parameters(): if param.dim() > 1: nn.init.normal_(param, mean=0.0, std=0.006) - + def forward( - self, - input_ids: Tensor, - input_mask: Optional[Tensor]=None, - persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None, - start_pos: int = 0 + self, + input_ids: Tensor, + input_mask: Optional[Tensor] = None, + persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None, + start_pos: int = 0, ) -> Tensor: assert input_ids.ndim == 2 - + x = self.embed_tokens(input_ids) rotary_emb = self.rotary_embeding(x, start_pos) - - attn_mask = process_attention_mask( - input_mask, x, start_pos, is_causal=True - ) - + + attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) + for layer in self.layers: x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos) - - hidden_states = self.norm(x) + + hidden_states = self.norm(x) logits = self.lm_head(hidden_states) - - return { - "logits": logits, - "hidden_states": hidden_states - } - \ No newline at end of file + + return {"logits": logits, "hidden_states": hidden_states} diff --git a/khaosz/parallel/__init__.py b/khaosz/parallel/__init__.py index 00254d4..896be8a 100644 --- a/khaosz/parallel/__init__.py +++ b/khaosz/parallel/__init__.py @@ -1,27 +1,21 @@ from khaosz.parallel.setup import ( - get_world_size, + get_world_size, get_rank, get_current_device, - only_on_rank, - setup_parallel, - spawn_parallel_fn + setup_parallel, + spawn_parallel_fn, ) -from khaosz.parallel.module import ( - RowParallelLinear, - ColumnParallelLinear -) +from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear __all__ = [ "get_world_size", "get_rank", "get_current_device", - "only_on_rank", "setup_parallel", "spawn_parallel_fn", - "RowParallelLinear", - "ColumnParallelLinear" + "ColumnParallelLinear", ] diff --git a/khaosz/parallel/module.py b/khaosz/parallel/module.py index 5f41434..3dd94e9 100644 --- a/khaosz/parallel/module.py +++ b/khaosz/parallel/module.py @@ -17,91 +17,99 @@ class ParallelModel(nn.Module): class RowParallelLinear(ParallelModel): def __init__( - self, + self, process_group: dist.ProcessGroup, - in_features: int, - out_features: int, + in_features: int, + out_features: int, bias: bool = True, - reduce_results: bool = True + reduce_results: bool = True, ): super().__init__(process_group) - + self.in_features = in_features self.out_features = out_features self.in_features_per_rank = in_features // self.world_size self.reduce_results = reduce_results - + if in_features % self.world_size != 0: - raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}") - + raise ValueError( + f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}" + ) + self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank)) self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None - + def forward(self, input: Tensor) -> Tensor: output = F.linear(input, self.weight) - + if self.reduce_results: dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) - + if self.bias is not None: output += self.bias - + return output - - def load_state_dict(self, state_dict: Dict[str, Tensor]): - full_weight = state_dict.get('weight') - full_bias = state_dict.get('bias') - + + def load_state_dict(self, state_dict: Dict[str, Tensor]): + full_weight = state_dict.get("weight") + full_bias = state_dict.get("bias") + start_idx = self.rank * self.in_features_per_rank end_idx = start_idx + self.in_features_per_rank weight_slice = full_weight[:, start_idx:end_idx] self.weight.data.copy_(weight_slice) - + if self.bias is not None: self.bias.data.copy_(full_bias) class ColumnParallelLinear(ParallelModel): def __init__( - self, + self, process_group: dist.ProcessGroup, - in_features: int, - out_features: int, + in_features: int, + out_features: int, bias: bool = True, - gather_results: bool = True + gather_results: bool = True, ): super().__init__(process_group) - + self.in_features = in_features self.out_features = out_features self.out_features_per_rank = out_features // self.world_size self.gather_results = gather_results if out_features % self.world_size != 0: - raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}") + raise ValueError( + f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}" + ) + + self.weight = nn.Parameter( + torch.empty(self.out_features_per_rank, self.in_features) + ) + self.bias = ( + nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None + ) - self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features)) - self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None - def forward(self, input: Tensor) -> Tensor: output = F.linear(input, self.weight, self.bias) - + if self.gather_results: output_list = [torch.empty_like(output) for _ in range(self.world_size)] dist.all_gather(output_list, output, group=self.process_group) output = torch.cat(output_list, dim=-1) - + return output - - def load_state_dict(self, state_dict: Dict[str, Tensor]): - full_weight = state_dict.get('weight') - full_bias = state_dict.get('bias') - + + def load_state_dict(self, state_dict: Dict[str, Tensor]): + full_weight = state_dict.get("weight") + full_bias = state_dict.get("bias") + start_idx = self.rank * self.out_features_per_rank end_idx = start_idx + self.out_features_per_rank weight_slice = full_weight[start_idx:end_idx, :] self.weight.data.copy_(weight_slice) - + if self.bias is not None: bias_slice = full_bias[start_idx:end_idx] - self.bias.data.copy_(bias_slice) \ No newline at end of file + self.bias.data.copy_(bias_slice) diff --git a/khaosz/parallel/setup.py b/khaosz/parallel/setup.py index 4c81b19..111fde4 100644 --- a/khaosz/parallel/setup.py +++ b/khaosz/parallel/setup.py @@ -11,73 +11,74 @@ from typing import Callable, List, Optional def get_current_device(): return os.environ["LOCAL_DEVICE"] + def get_world_size() -> int: if dist.is_available() and dist.is_initialized(): return dist.get_world_size() else: return 1 + def get_rank() -> int: if dist.is_available() and dist.is_initialized(): return dist.get_rank() else: return 0 + @contextmanager def setup_parallel( - rank: int, + rank: int, world_size: int, backend: str = "nccl", master_addr: str = "localhost", master_port: str = "29500", device_type: str = "cuda", - device_ids: Optional[List[int]] = None + device_ids: Optional[List[int]] = None, ): if dist.is_available() and dist.is_initialized(): yield dist.group.WORLD - return + return if world_size <= 1: yield None return - + if device_ids is None: device_ids = [i for i in range(world_size)] - + rank = device_ids[rank % len(device_ids)] device_id = torch.device(device_type, device_ids[rank]) - - os.environ['MASTER_ADDR'] = master_addr - os.environ['MASTER_PORT'] = master_port - - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) + + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_DEVICE"] = str(device_id) - + dist.init_process_group( - rank=rank, - world_size=world_size, - backend=backend, - device_id=device_id + rank=rank, world_size=world_size, backend=backend, device_id=device_id ) - + try: if backend == "nccl" and torch.cuda.is_available(): torch.cuda.set_device(device_id) - elif backend == "ccl" and hasattr(torch, 'xpu') and torch.xpu.is_available(): + elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.set_device(device_id) - + yield dist.group.WORLD finally: if dist.is_initialized(): dist.destroy_process_group() + def only_on_rank(rank, sync=False): """ decorator to run a function only on a specific rank. """ - + def decorator(func): @wraps(func) def wrapper(*args, **kwargs): @@ -89,67 +90,81 @@ def only_on_rank(rank, sync=False): dist.barrier() return ret_args - + return wrapper - + return decorator + def wrapper_spawn_func( - rank: int, - world_size: int, - backend: str, - master_addr: str, + rank: int, + world_size: int, + backend: str, + master_addr: str, master_port: str, device_type: str, - device_ids: List[int], - func: Callable, - kwargs: dict + device_ids: List[int], + func: Callable, + kwargs: dict, ): try: with setup_parallel( - rank=rank, - world_size=world_size, - backend=backend, - master_addr=master_addr, + rank=rank, + world_size=world_size, + backend=backend, + master_addr=master_addr, master_port=master_port, device_type=device_type, - device_ids=device_ids + device_ids=device_ids, ): func(**kwargs) - + except Exception as e: print(f"Error in rank {rank}: {e}") raise + def spawn_parallel_fn( - func: Callable, - world_size: int, + func: Callable, + world_size: int, backend: str = "nccl", master_addr: str = "localhost", master_port: str = "29500", device_type: str = "cuda", device_ids: Optional[List[int]] = None, - **kwargs + **kwargs, ): # clear environment variables - for key in ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_DEVICE']: + for key in [ + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_DEVICE", + ]: if key in os.environ: del os.environ[key] - + if world_size == 1: device_ids = device_ids or [0] device_id = torch.device(device_type, device_ids[0]) os.environ["LOCAL_DEVICE"] = str(device_id) - + func(**kwargs) return - wrapper_spawn_func_args = (world_size, backend, master_addr, master_port, - device_type, device_ids, func, kwargs) + wrapper_spawn_func_args = ( + world_size, + backend, + master_addr, + master_port, + device_type, + device_ids, + func, + kwargs, + ) mp.spawn( - wrapper_spawn_func, - nprocs=world_size, - args=wrapper_spawn_func_args, - join=True - ) \ No newline at end of file + wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True + ) diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index af1bd3b..b266a66 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -14,15 +14,12 @@ from khaosz.trainer.train_callback import ( __all__ = [ # Main trainer "Trainer", - # Strategy factory "StrategyFactory", "BaseStrategy", - # Scheduler factory "SchedulerFactory", "BaseScheduler", - # Callbacks "TrainCallback", "GradientClippingCallback", @@ -30,4 +27,4 @@ __all__ = [ "CheckpointCallback", "ProgressBarCallback", "MetricLoggerCallback", -] \ No newline at end of file +] diff --git a/khaosz/trainer/metric_util.py b/khaosz/trainer/metric_util.py index 5d425f4..920b2a6 100644 --- a/khaosz/trainer/metric_util.py +++ b/khaosz/trainer/metric_util.py @@ -1,8 +1,9 @@ import torch.nn as nn from typing import Dict + def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: - """ Compute gradient norm for each parameter in the model. """ + """Compute gradient norm for each parameter in the model.""" norms = {} for name, param in model.named_parameters(): norms[name] = 0.0 @@ -11,8 +12,9 @@ def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: norms[name] = norm return norms + def grad_std(model: nn.Module) -> Dict[str, float]: - """ Compute standard deviation of gradients for each parameter. """ + """Compute standard deviation of gradients for each parameter.""" stds = {} for name, param in model.named_parameters(): stds[name] = 0.0 @@ -21,41 +23,45 @@ def grad_std(model: nn.Module) -> Dict[str, float]: stds[name] = std return stds + def grad_max(model: nn.Module) -> Dict[str, float]: - """ Find the maximum absolute gradient value for each parameter. """ + """Find the maximum absolute gradient value for each parameter.""" max_vals = {} for name, param in model.named_parameters(): - max_vals[name] = -float('inf') + max_vals[name] = -float("inf") if param.grad: max_val = param.grad.data.max().item() max_vals[name] = max_val - + return max_vals + def grad_min(model: nn.Module) -> Dict[str, float]: - """ Find the minimum absolute gradient value for each parameter. """ + """Find the minimum absolute gradient value for each parameter.""" min_vals = {} for name, param in model.named_parameters(): - min_vals[name] = float('inf') + min_vals[name] = float("inf") if param.grad: min_val = param.grad.data.min().item() min_vals[name] = min_val - + return min_vals + def grad_mean(model: nn.Module) -> Dict[str, float]: - """ Compute mean of gradients for each parameter. """ + """Compute mean of gradients for each parameter.""" means = {} for name, param in model.named_parameters(): means[name] = 0.0 if param.grad: mean = param.grad.data.mean().item() means[name] = mean - + return means + def grad_nan_num(model: nn.Module) -> Dict[str, int]: - """ Count the number of NaNs in gradients for each parameter. """ + """Count the number of NaNs in gradients for each parameter.""" nan_nums = {} for name, param in model.named_parameters(): nan_nums[name] = 0 @@ -64,26 +70,34 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]: nan_nums[name] = nan_num return nan_nums + def ctx_get_loss(ctx): return ctx.loss + def ctx_get_lr(ctx): - return ctx.optimizer.param_groups[-1]['lr'] + return ctx.optimizer.param_groups[-1]["lr"] + def ctx_get_grad_norm(ctx): return grad_norm(ctx.model) + def ctx_get_grad_std(ctx): return grad_std(ctx.model) + def ctx_get_grad_max(ctx): return grad_max(ctx.model) + def ctx_get_grad_min(ctx): return grad_min(ctx.model) + def ctx_get_grad_mean(ctx): return grad_mean(ctx.model) + def ctx_get_grad_nan_num(ctx): - return grad_nan_num(ctx.model) \ No newline at end of file + return grad_nan_num(ctx.model) diff --git a/khaosz/trainer/schedule.py b/khaosz/trainer/schedule.py index 8853690..453d2ee 100644 --- a/khaosz/trainer/schedule.py +++ b/khaosz/trainer/schedule.py @@ -9,71 +9,75 @@ from khaosz.config.schedule_config import ScheduleConfig class BaseScheduler(LRScheduler, ABC): """Base scheduler class for all other schedulers.""" - + def __init__(self, optimizer, last_epoch: int = -1): super().__init__(optimizer, last_epoch) - + @abstractmethod def get_lr(self) -> List[float]: """Calculate the current learning rate.""" raise NotImplementedError - + def state_dict(self) -> Dict[str, Any]: return super().state_dict() - + def load_state_dict(self, state_dict: Dict[str, Any]): super().load_state_dict(state_dict) class SchedulerFactory: """Factory class for creating learning rate schedulers. - + Supports decorator-based registration for extensible scheduler types. Also supports creation from ScheduleConfig objects. - + Example usage: @SchedulerFactory.register("custom") class CustomScheduler(BaseScheduler): ... - + scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs) - + # Or from config config = CosineScheduleConfig(total_steps=10000) scheduler = SchedulerFactory.load(optimizer, config) """ - + SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {} - + @classmethod def register(cls, name: str): """Decorator to register a new scheduler class. - + Args: name: Registration name for the scheduler - + Returns: Decorator function that registers the scheduler class """ + def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]: if not issubclass(scheduler_cls, BaseScheduler): - raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") + raise TypeError( + f"{scheduler_cls.__name__} must inherit from BaseScheduler" + ) cls.SCHEDULER_MAP[name] = scheduler_cls return scheduler_cls + return decorator - + @classmethod def create(cls, optimizer, schedule_type: str, **kwargs) -> BaseScheduler: """Create a scheduler instance by type name. - + Args: optimizer: PyTorch optimizer schedule_type: Type of scheduler ("cosine", "sgdr") **kwargs: Arguments passed to the scheduler constructor - + Returns: Scheduler instance - + Raises: ValueError: If schedule_type is not supported """ @@ -82,25 +86,25 @@ class SchedulerFactory: f"Unknown schedule type: '{schedule_type}'. " f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}" ) - + scheduler_cls = cls.SCHEDULER_MAP[schedule_type] return scheduler_cls(optimizer, **kwargs) - + @staticmethod def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler: """Create a scheduler from a ScheduleConfig object. - + Args: optimizer: PyTorch optimizer schedule_config: ScheduleConfig instance - + Returns: Scheduler instance """ kwargs = schedule_config.get_kwargs() schedule_type = kwargs.pop("schedule_type") return SchedulerFactory.create(optimizer, schedule_type, **kwargs) - + @classmethod def available_types(cls) -> list: """Return list of registered scheduler type names.""" @@ -114,22 +118,21 @@ class SchedulerFactory: @SchedulerFactory.register("cosine") class CosineScheduler(BaseScheduler): """Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler.""" - + def __init__( - self, - optimizer, - warmup_steps: int, - lr_decay_steps: int, - min_rate: float = 0.05, - last_epoch: int = -1 + self, + optimizer, + warmup_steps: int, + lr_decay_steps: int, + min_rate: float = 0.05, + last_epoch: int = -1, ): self.warmup_steps = warmup_steps self.lr_decay_steps = lr_decay_steps self.min_rate = min_rate self.total_steps = warmup_steps + lr_decay_steps super().__init__(optimizer, last_epoch) - - + def get_lr(self) -> List[float]: # warmup if self.last_epoch < self.warmup_steps: @@ -142,46 +145,47 @@ class CosineScheduler(BaseScheduler): cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress)) decay_factor = max(self.min_rate, cosine_decay) return [base_lr * decay_factor for base_lr in self.base_lrs] - + def state_dict(self): state = super().state_dict() - state.update({ - 'warmup_steps': self.warmup_steps, - 'lr_decay_steps': self.lr_decay_steps, - 'min_rate': self.min_rate, - 'total_steps': self.total_steps, - }) + state.update( + { + "warmup_steps": self.warmup_steps, + "lr_decay_steps": self.lr_decay_steps, + "min_rate": self.min_rate, + "total_steps": self.total_steps, + } + ) return state - + def load_state_dict(self, state_dict): - self.warmup_steps = state_dict.pop('warmup_steps') - self.lr_decay_steps = state_dict.pop('lr_decay_steps') - self.min_rate = state_dict.pop('min_rate') - self.total_steps = state_dict.pop('total_steps') + self.warmup_steps = state_dict.pop("warmup_steps") + self.lr_decay_steps = state_dict.pop("lr_decay_steps") + self.min_rate = state_dict.pop("min_rate") + self.total_steps = state_dict.pop("total_steps") super().load_state_dict(state_dict) @SchedulerFactory.register("sgdr") class SGDRScheduler(BaseScheduler): """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler.""" - + def __init__( - self, - optimizer, - warmup_steps: int, - cycle_length: int, - min_rate: float = 0.05, - t_mult: int = 2, - last_epoch: int = -1, + self, + optimizer, + warmup_steps: int, + cycle_length: int, + min_rate: float = 0.05, + t_mult: int = 2, + last_epoch: int = -1, ): self.warmup_steps = warmup_steps self.cycle_length = cycle_length self.min_rate = min_rate self.t_mult = t_mult - + super().__init__(optimizer, last_epoch) - def get_lr(self): # warmup if self.last_epoch < self.warmup_steps: @@ -190,40 +194,44 @@ class SGDRScheduler(BaseScheduler): # SGDR steps_since_warmup = self.last_epoch - self.warmup_steps - + # 1. Calculate current cycle and position within cycle current_cycle_length = self.cycle_length total_cycles_length = 0 cycle_num = 0 - + while total_cycles_length + current_cycle_length <= steps_since_warmup: total_cycles_length += current_cycle_length current_cycle_length *= self.t_mult cycle_num += 1 - + steps_in_cycle = steps_since_warmup - total_cycles_length - + # 2. Cosine annealing within the current cycle - cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)) + cosine_factor = 0.5 * ( + 1 + math.cos(math.pi * steps_in_cycle / current_cycle_length) + ) learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor - + return [base_lr * learning_rate_factor for base_lr in self.base_lrs] - + def state_dict(self): """Returns the state of the scheduler as a dict.""" state = super().state_dict() - state.update({ - 'warmup_steps': self.warmup_steps, - 'cycle_length': self.cycle_length, - 'min_rate': self.min_rate, - 't_mult': self.t_mult - }) + state.update( + { + "warmup_steps": self.warmup_steps, + "cycle_length": self.cycle_length, + "min_rate": self.min_rate, + "t_mult": self.t_mult, + } + ) return state - + def load_state_dict(self, state_dict): """Loads the scheduler's state.""" - self.warmup_steps = state_dict.pop('warmup_steps') - self.cycle_length = state_dict.pop('cycle_length') - self.min_rate = state_dict.pop('min_rate') - self.t_mult = state_dict.pop('t_mult') - super().load_state_dict(state_dict) \ No newline at end of file + self.warmup_steps = state_dict.pop("warmup_steps") + self.cycle_length = state_dict.pop("cycle_length") + self.min_rate = state_dict.pop("min_rate") + self.t_mult = state_dict.pop("t_mult") + super().load_state_dict(state_dict) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 85679d2..db877be 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -20,7 +20,7 @@ def unwrap_model(model: nn.Module) -> nn.Module: def create_ref_model(model: nn.Module) -> nn.Module: """Create a reference model for DPO/GRPO training. - + Handles DDP-wrapped models safely by unwrapping first, then creating a deep copy with frozen gradients. """ @@ -37,25 +37,27 @@ def move_to_device(batch: Dict[str, Tensor], device: str) -> Any: def get_logprobs( - model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], - input_ids: Tensor, - mask: Tensor, + model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], + input_ids: Tensor, + mask: Tensor, reduction: str, ): """Compute token-wise log probabilities from model outputs. - + Args: model: The language model input_ids: Input token IDs of shape [batch_size, seq_len] mask: Attention mask of shape [batch_size, seq_len] reduction: How to reduce over sequence dimension ("mean", "sum", "none") - + Returns: Log probabilities with reduction applied over sequence dimension """ allowed_reductions = ["mean", "sum", "none"] if reduction not in allowed_reductions: - raise ValueError(f"reduction must be one of {allowed_reductions}, got '{reduction}'") + raise ValueError( + f"reduction must be one of {allowed_reductions}, got '{reduction}'" + ) shifted_input_ids = input_ids[:, 1:] shifted_mask = mask[:, 1:] @@ -64,13 +66,13 @@ def get_logprobs( log_probs = torch.log_softmax(logits.float(), dim=-1) token_logprobs = torch.gather( - log_probs, - dim=-1, - index=shifted_input_ids.unsqueeze(-1) + log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1) ).squeeze(-1) - + if reduction == "mean": - return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(dim=-1).clamp(min=1.0) + return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum( + dim=-1 + ).clamp(min=1.0) elif reduction == "sum": return (token_logprobs * shifted_mask).sum(dim=-1) else: @@ -79,23 +81,25 @@ def get_logprobs( class BaseStrategy(ABC): """Abstract base class for training strategies.""" - - def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): + + def __init__( + self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str + ): self.model = model self.device = device - + @abstractmethod def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: """Compute loss for the given batch. - + Args: batch: Dictionary containing batch tensors - + Returns: Computed loss tensor """ raise NotImplementedError - + def __call__(self, batch: Dict[str, Tensor]) -> Tensor: """Allow calling strategy directly as a callable.""" return self.compute_loss(batch) @@ -103,51 +107,55 @@ class BaseStrategy(ABC): class StrategyFactory: """Factory class for creating training strategy instances. - + Supports decorator-based registration for extensible strategy types. All default strategies (seq, sft, dpo, grpo) are automatically registered. - + Example usage: @StrategyFactory.register("custom") class CustomStrategy(BaseStrategy): ... - + strategy = StrategyFactory.create(model, "custom", device) """ - + SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"}) STRATEGY_MAP: Dict[str, type] = {} - + @classmethod def register(cls, name: str): """Decorator to register a new strategy class. - + Args: name: Registration name for the strategy - + Returns: Decorator function that registers the strategy class """ + def decorator(strategy_cls: type) -> type: if not issubclass(strategy_cls, BaseStrategy): - raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") + raise TypeError( + f"{strategy_cls.__name__} must inherit from BaseStrategy" + ) cls.STRATEGY_MAP[name] = strategy_cls return strategy_cls + return decorator - + @classmethod def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy: """Create a strategy instance based on training type. - + Args: model: Model instance for the strategy train_type: Type of training ("seq", "sft", "dpo", "grpo") device: Device to run the strategy on **kwargs: Additional arguments passed to strategy constructor - + Returns: Strategy instance - + Raises: ValueError: If train_type is not supported NotImplementedError: If train_type is in supported list but not implemented @@ -157,15 +165,15 @@ class StrategyFactory: f"Unknown training strategy: '{train_type}'. " f"Supported strategies: {sorted(cls.SUPPORTED_STRATEGIES)}" ) - + if train_type not in cls.STRATEGY_MAP: raise NotImplementedError( f"Strategy '{train_type}' is supported but not yet implemented." ) - + strategy_cls = cls.STRATEGY_MAP[train_type] return strategy_cls(model, device, **kwargs) - + @classmethod def available_strategies(cls) -> list: """Return list of registered strategy names.""" @@ -179,77 +187,81 @@ class StrategyFactory: @StrategyFactory.register("seq") class SEQStrategy(BaseStrategy): """Standard next-token prediction training strategy. - + Computes cross-entropy loss for next token prediction. """ - + def __init__(self, model, device, label_smoothing: float = 0.0): super().__init__(model, device) self.label_smoothing = label_smoothing - + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] logits = self.model(input_ids=input_ids)["logits"] - + loss = F.cross_entropy( input=logits.flatten(0, 1).float(), target=target_ids.flatten(), - label_smoothing=self.label_smoothing + label_smoothing=self.label_smoothing, ) - + return loss @StrategyFactory.register("sft") class SFTStrategy(BaseStrategy): """Supervised Fine-tuning strategy with loss masking. - + Applies cross-entropy loss only to tokens where loss_mask is True. """ - + def __init__(self, model, device, label_smoothing: float = 0.0): super().__init__(model, device) self.label_smoothing = label_smoothing - + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) - input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"] - + input_ids, target_ids, loss_mask = ( + batch["input_ids"], + batch["target_ids"], + batch["loss_mask"], + ) + ignore_index = -100 logits = self.model(input_ids=input_ids)["logits"] target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) - + loss = F.cross_entropy( input=logits.flatten(0, 1).float(), target=target_ids.flatten(), ignore_index=ignore_index, - label_smoothing=self.label_smoothing + label_smoothing=self.label_smoothing, ) - + return loss @StrategyFactory.register("dpo") class DPOStrategy(BaseStrategy): """Direct Preference Optimization strategy. - + Implements the DPO loss from the paper "Direct Preference Optimization". Uses a reference model to compute KL divergence penalty. """ - + def __init__( - self, - model: nn.Module, - device: str, - beta: float = 0.1, - reduction: str = "mean", - ): + self, + model: nn.Module, + device: str, + beta: float = 0.1, + reduction: str = "mean", + ): super().__init__(model, device) self.ref_model = create_ref_model(model) self.beta = beta self.reduction = reduction - + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) chosen_ids, rejected_ids = batch["chosen"], batch["rejected"] @@ -257,17 +269,19 @@ class DPOStrategy(BaseStrategy): contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0) contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0) - + log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction) with torch.no_grad(): - log_ref = get_logprobs(self.ref_model, contact_ids, contact_mask, self.reduction) - - log_pi_chosen = log_pi[:chosen_ids.shape[0]] - log_pi_rejected = log_pi[chosen_ids.shape[0]:] - log_ref_chosen = log_ref[:chosen_ids.shape[0]] - log_ref_rejected = log_ref[chosen_ids.shape[0]:] - + log_ref = get_logprobs( + self.ref_model, contact_ids, contact_mask, self.reduction + ) + + log_pi_chosen = log_pi[: chosen_ids.shape[0]] + log_pi_rejected = log_pi[chosen_ids.shape[0] :] + log_ref_chosen = log_ref[: chosen_ids.shape[0]] + log_ref_rejected = log_ref[chosen_ids.shape[0] :] + pi_log_ratio = log_pi_chosen - log_pi_rejected ref_log_ratio = log_ref_chosen - log_ref_rejected @@ -280,14 +294,14 @@ class DPOStrategy(BaseStrategy): @StrategyFactory.register("grpo") class GRPOStrategy(BaseStrategy): """Group Relative Policy Optimization strategy. - + Implements GRPO with clipping and KL penalty. """ - + def __init__( - self, - model: nn.Module, - device: str, + self, + model: nn.Module, + device: str, clip_eps: float = 0.2, kl_coef: float = 0.01, group_size: int = 4, @@ -299,43 +313,47 @@ class GRPOStrategy(BaseStrategy): self.kl_coef = kl_coef self.group_size = group_size self.reduction = reduction - + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) prompts = batch["prompts"] responses = batch["responses"] masks = batch["masks"] rewards = batch["rewards"] - + batch_size, group_size, response_len = responses.shape responses_flat = responses.view(-1, response_len) masks_flat = masks.view(-1, response_len) prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1) - + # Shape: (batch_size * group_size, seq_len + response_len) full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1) full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1) - - log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction) + + log_probs_policy = get_logprobs( + self.model, full_sequences, full_masks, self.reduction + ) log_probs_policy = log_probs_policy.view(batch_size, group_size) - + with torch.no_grad(): - log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction) + log_probs_ref = get_logprobs( + self.ref_model, full_sequences, full_masks, self.reduction + ) log_probs_ref = log_probs_ref.view(batch_size, group_size) - + # Compute advantages from rewards with normalization eps = torch.finfo(log_probs_policy.dtype).eps mean = rewards.mean(dim=-1, keepdim=True) std = rewards.std(dim=-1, keepdim=True) advantages = (rewards - mean) / (std + eps) - + # PPO-style clipped surrogate objective ratio = torch.exp(0) # Off-policy: policy_model = old_model surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages - + policy_loss = -torch.min(surr1, surr2).mean() kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean() total_loss = policy_loss + kl_penalty - + return total_loss diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 986b72a..0b1a94a 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -18,52 +18,53 @@ from khaosz.trainer.metric_util import ( ctx_get_grad_norm, ctx_get_grad_mean, ctx_get_grad_std, - ctx_get_grad_nan_num + ctx_get_grad_nan_num, ) from khaosz.data.serialization import Checkpoint from khaosz.trainer.train_context import TrainContext class TrainCallback(Protocol): - """ + """ Callback interface for trainer. """ def on_train_begin(self, context: TrainContext): - """ Called at the beginning of training. """ - + """Called at the beginning of training.""" + def on_train_end(self, context: TrainContext): - """ Called at the end of training. """ + """Called at the end of training.""" def on_epoch_begin(self, context: TrainContext): - """ Called at the beginning of each epoch. """ + """Called at the beginning of each epoch.""" def on_epoch_end(self, context: TrainContext): - """ Called at the end of each epoch. """ - + """Called at the end of each epoch.""" + def on_step_begin(self, context: TrainContext): - """ Called at the beginning of each step. """ + """Called at the beginning of each step.""" def on_step_end(self, context: TrainContext): - """ Called at the end of each step.""" + """Called at the end of each step.""" def on_batch_begin(self, context: TrainContext): - """ Called at the beginning of each batch. """ + """Called at the beginning of each batch.""" def on_batch_end(self, context: TrainContext): - """ Called at the end of each batch. """ - + """Called at the end of each batch.""" + def on_error(self, context: TrainContext): - """ Called when an error occurs during training. """ + """Called when an error occurs during training.""" class GradientClippingCallback(TrainCallback): - """ + """ Gradient clipping callback for trainer. """ + def __init__(self, max_grad_norm: float): self.max_grad_norm = max_grad_norm - + def on_step_begin(self, context: TrainContext): _ = context clip_grad_norm_(context.model.parameters(), self.max_grad_norm) @@ -73,86 +74,95 @@ class SchedulerCallback(TrainCallback): """ Scheduler callback for trainer. """ + def __init__(self): pass - + def on_train_begin(self, context: TrainContext): for group in context.optimizer.param_groups: if "initial_lr" not in group: - group["initial_lr"] = group["lr"] - + group["initial_lr"] = group["lr"] + def on_batch_end(self, context: TrainContext): if context.scheduler: context.scheduler.step() class CheckpointCallback(TrainCallback): - """ + """ Checkpoint callback for trainer. """ + def __init__( - self, - save_dir: str, + self, + save_dir: str, interval: int, weight_only: bool = False, - state_dict_fn: Optional[Callable[[nn.Module], dict]] = None + state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only self.state_dict_fn = state_dict_fn self.last_ckpt_iter = 0 - + @only_on_rank(0) def _save_checkpoint(self, context: TrainContext): - save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") - state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict() - + save_path = os.path.join( + self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}" + ) + state_dict = ( + self.state_dict_fn(context.model) + if self.state_dict_fn + else context.model.state_dict() + ) + context.checkpoint = Checkpoint( - state_dict=state_dict, - epoch=context.epoch, - iteration=context.iteration + state_dict=state_dict, epoch=context.epoch, iteration=context.iteration ) context.checkpoint.save(save_path) self.last_ckpt_iter = context.iteration - + def on_batch_end(self, context: TrainContext): if context.iteration - self.last_ckpt_iter >= self.interval: self._save_checkpoint(context) - + def on_train_end(self, context: TrainContext): if context.iteration != self.last_ckpt_iter: self._save_checkpoint(context) - + def on_error(self, context: TrainContext): self._save_checkpoint(context) class ProgressBarCallback(TrainCallback): - """ + """ Progress bar callback for trainer. """ + def __init__(self, num_epoch: int): self.num_epoch = num_epoch self.progress_bar: tqdm = None - + @only_on_rank(0) def on_epoch_begin(self, context: TrainContext): self.progress_bar = tqdm( - context.dataloader, - desc=f"Epoch {context.epoch+1}/{self.num_epoch}", - dynamic_ncols=True + context.dataloader, + desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", + dynamic_ncols=True, ) - + @only_on_rank(0) def on_batch_end(self, context: TrainContext): - self.progress_bar.set_postfix({ - "loss": f"{context.loss:.4f}", - "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}" - }) + self.progress_bar.set_postfix( + { + "loss": f"{context.loss:.4f}", + "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}", + } + ) self.progress_bar.update(1) - + @only_on_rank(0) def on_epoch_end(self, context: TrainContext): _ = context @@ -162,66 +172,65 @@ class ProgressBarCallback(TrainCallback): class MetricLoggerCallback(TrainCallback): def __init__( - self, - log_dir:str, - save_interval:int, - log_interval:int=10, - metrics:List[str]=None + self, + log_dir: str, + save_interval: int, + log_interval: int = 10, + metrics: List[str] = None, ): self.last_log_iter = 0 self.save_interval = save_interval self.log_interval = log_interval - self.metrics = metrics or ['loss', 'lr'] - + self.metrics = metrics or ["loss", "lr"] + self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs" self.log_dir.mkdir(parents=True, exist_ok=True) - + self.log_cache = [] - + self._metric_funcs = { - 'loss': ctx_get_loss, - 'lr': ctx_get_lr, - 'grad_norm': ctx_get_grad_norm, - 'grad_std': ctx_get_grad_std, - 'grad_max': ctx_get_grad_max, - 'grad_min': ctx_get_grad_min, - 'grad_mean': ctx_get_grad_mean, - 'grad_nan_num': ctx_get_grad_nan_num + "loss": ctx_get_loss, + "lr": ctx_get_lr, + "grad_norm": ctx_get_grad_norm, + "grad_std": ctx_get_grad_std, + "grad_max": ctx_get_grad_max, + "grad_min": ctx_get_grad_min, + "grad_mean": ctx_get_grad_mean, + "grad_nan_num": ctx_get_grad_nan_num, } def _get_log_data(self, context: TrainContext): return { - "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "epoch": context.epoch, "iter": context.iteration, - **{m: self._metric_funcs[m](context) for m in self.metrics} + **{m: self._metric_funcs[m](context) for m in self.metrics}, } - + @only_on_rank(0) def _add_log(self, log_data): self.log_cache.append(log_data) - + @only_on_rank(0) def _save_log(self, epoch, iter): log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl" - - with open(log_file, 'w') as f: + + with open(log_file, "w") as f: for log in self.log_cache: - f.write(json.dumps(log) + '\n') - + f.write(json.dumps(log) + "\n") + def on_batch_end(self, context): if context.iteration % self.log_interval == 0: log_data = self._get_log_data(context) self._add_log(log_data) - + if context.iteration - self.last_log_iter >= self.save_interval: self._save_log(context.epoch, context.iteration) self.last_log_iter = context.iteration - + def on_train_end(self, context): if context.iteration != self.last_log_iter: self._save_log(context.epoch, context.iteration) - + def on_error(self, context): self._save_log(context.epoch, context.iteration) - \ No newline at end of file diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index 0f7e08c..f718ffd 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -21,11 +21,11 @@ class TrainContext: optimizer: Optimizer = field(default=None) scheduler: LRScheduler = field(default=None) checkpoint: Checkpoint = field(default=None) - + epoch: int = field(default=0) iteration: int = field(default=0) loss: float = field(default=0.0) - + world_size: int = field(default=1) rank: int = field(default=0) kwargs: dict = field(default_factory=dict) @@ -39,17 +39,17 @@ class TrainContextBuilder: world_size=get_world_size(), rank=get_rank(), ) - + device = get_current_device() self._context.model = self._context.model.to(device=device) - + if self.config.nprocs > 1: fn = self.config.parallel_wrapper self._context.model = fn(self._context.model) - + self._context.optimizer = self.config.optimizer_fn(self._context.model) self._context.scheduler = self.config.scheduler_fn(self._context.optimizer) - + def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: if checkpoint is None: checkpoint = Checkpoint( @@ -60,10 +60,10 @@ class TrainContextBuilder: self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) self._context.iteration = max(checkpoint.iteration, self.config.start_batch) self._context.model.load_state_dict(checkpoint.state_dict) - + self._context.checkpoint = checkpoint return self - + def with_dataloader(self) -> Self: # fix: change batch level iteration to sample level offset config = self.config @@ -72,28 +72,28 @@ class TrainContextBuilder: data_source=config.dataset, start_epoch=self._context.epoch, start_iter=sampler_offset, - seed=config.random_seed + seed=config.random_seed, ) - + dataloader = DataLoader( - config.dataset, - batch_size=config.batch_size, + config.dataset, + batch_size=config.batch_size, sampler=resumeable_sampler, num_workers=config.num_workers, pin_memory=config.pin_memory, - prefetch_factor=config.prefetch_factor + prefetch_factor=config.prefetch_factor, ) self._context.dataloader = dataloader return self - + def with_strategy(self) -> Self: self._context.strategy = StrategyFactory.create( model=self._context.model, train_type=self.config.strategy, device=get_current_device(), - **self.config.extra_kwargs + **self.config.extra_kwargs, ) return self - + def build(self) -> TrainContext: - return self._context \ No newline at end of file + return self._context diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 6d2415f..179fdd4 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -2,12 +2,12 @@ import logging from typing import Optional, List from khaosz.config import TrainConfig from khaosz.trainer.train_callback import ( - TrainCallback, - ProgressBarCallback, + TrainCallback, + ProgressBarCallback, CheckpointCallback, MetricLoggerCallback, GradientClippingCallback, - SchedulerCallback + SchedulerCallback, ) from khaosz.trainer.train_context import TrainContext, TrainContextBuilder from khaosz.data.serialization import Checkpoint @@ -18,37 +18,39 @@ logger = logging.getLogger(__name__) class Trainer: def __init__( - self, - train_config: TrainConfig, - callbacks: Optional[List[TrainCallback]] = None + self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None ): self.train_config = train_config default_callbacks = self._get_default_callbacks() - self.callbacks = default_callbacks + callbacks if callbacks else default_callbacks + self.callbacks = ( + default_callbacks + callbacks if callbacks else default_callbacks + ) def _get_default_callbacks(self) -> List[TrainCallback]: train_config = self.train_config return [ ProgressBarCallback(train_config.n_epoch), - CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), - MetricLoggerCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), + CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval), + MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval), GradientClippingCallback(train_config.max_grad_norm), SchedulerCallback(), ] - + def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: - return (TrainContextBuilder(self.train_config) - .with_checkpoint(checkpoint) - .with_dataloader() - .with_strategy() - .build()) - + return ( + TrainContextBuilder(self.train_config) + .with_checkpoint(checkpoint) + .with_dataloader() + .with_strategy() + .build() + ) + def _call_callbacks(self, method_name: str, context: TrainContext): for callback in self.callbacks: method = getattr(callback, method_name, None) if method: method(context) - + def train(self, checkpoint: Optional[Checkpoint] = None): config = self.train_config spawn_parallel_fn( @@ -59,45 +61,45 @@ class Trainer: master_port=config.master_port, device_type=config.device_type, device_ids=config.device_ids, - checkpoint=checkpoint + checkpoint=checkpoint, ) def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: context = self._build_context(checkpoint) - self._call_callbacks('on_train_begin', context) - + self._call_callbacks("on_train_begin", context) + try: context.model.train() # 1.epoch for epoch in range(context.epoch, self.train_config.n_epoch): context.epoch = epoch - self._call_callbacks('on_epoch_begin', context) - + self._call_callbacks("on_epoch_begin", context) + for batch in context.dataloader: # 3. batch - self._call_callbacks('on_batch_begin', context) + self._call_callbacks("on_batch_begin", context) loss = context.strategy(batch) context.loss = loss.item() context.iteration += 1 - + # to make the loss normalized by accumulation steps stand_loss = loss / self.train_config.accumulation_steps stand_loss.backward() - - self._call_callbacks('on_batch_end', context) + + self._call_callbacks("on_batch_end", context) if context.iteration % self.train_config.accumulation_steps == 0: # 2. step - self._call_callbacks('on_step_begin', context) + self._call_callbacks("on_step_begin", context) context.optimizer.step() context.optimizer.zero_grad() - self._call_callbacks('on_step_end', context) + self._call_callbacks("on_step_end", context) + + self._call_callbacks("on_epoch_end", context) - self._call_callbacks('on_epoch_end', context) - except Exception as e: logger.error(f"Training failed: {str(e)}", exc_info=True) - self._call_callbacks('on_error', context) + self._call_callbacks("on_error", context) raise finally: - self._call_callbacks('on_train_end', context) \ No newline at end of file + self._call_callbacks("on_train_end", context) diff --git a/pyproject.toml b/pyproject.toml index 2e89daf..bd0a69c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" } [project.optional-dependencies] -dev = ["pytest==9.0.2"] +dev = ["pytest==9.0.2", "ruff"] [tool.setuptools.packages.find] where = ["."] @@ -35,4 +35,13 @@ where = ["."] extra-index-url = "https://download.pytorch.org/whl/cu126" [tool.setuptools.dynamic] -version = { attr = "khaosz.__version__" } \ No newline at end of file +version = { attr = "khaosz.__version__" } + +[tool.ruff] +target-version = "py312" + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 2dacbbb..b21fa92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,14 +17,14 @@ class RandomDataset(Dataset): self.length = length or int(np.random.randint(100, 200)) self.max_length = max_length self.vocab_size = vocab_size - + def __len__(self): return self.length - + def __getitem__(self, idx): return { "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)), - "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)) + "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)), } @@ -33,10 +33,10 @@ class MultiTurnDataset(Dataset): self.length = length or int(np.random.randint(100, 200)) self.max_length = max_length self.vocab_size = vocab_size - + def __len__(self): return self.length - + def __getitem__(self, idx): input_ids = torch.randint(0, self.vocab_size, (self.max_length,)) target_ids = torch.randint(0, self.vocab_size, (self.max_length,)) @@ -54,18 +54,18 @@ class EarlyStoppingDataset(Dataset): self.length = length self.stop_after = stop_after self.count = 0 - + def __len__(self): return self.length - + def __getitem__(self, idx): self.count += 1 if self.count == self.stop_after: raise RuntimeError("Simulated early stopping") - + return { "input_ids": torch.randint(0, 1000, (64,)), - "target_ids": torch.randint(0, 1000, (64,)) + "target_ids": torch.randint(0, 1000, (64,)), } @@ -74,10 +74,10 @@ def base_test_env(request: pytest.FixtureRequest): func_name = request.function.__name__ test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") - + n_dim_choices = [8, 16, 32] n_head_choices = [2, 4] - + dim = int(np.random.choice(n_dim_choices)) n_heads = int(np.random.choice(n_head_choices)) n_kv_heads = n_heads // 2 @@ -91,16 +91,16 @@ def base_test_env(request: pytest.FixtureRequest): "dim_ffn": dim_ffn, "max_len": 1024, "n_layers": 4, - "norm_eps": 1e-5 + "norm_eps": 1e-5, } - - with open(config_path, 'w') as f: + + with open(config_path, "w") as f: json.dump(config, f) device = "cuda" if torch.cuda.is_available() else "cpu" transformer_config = ModelConfig().load(config_path) model = Transformer(transformer_config).to(device=device) tokenizer = BpeTokenizer() - + yield { "device": device, "test_dir": str(test_dir), @@ -109,20 +109,23 @@ def base_test_env(request: pytest.FixtureRequest): "model": model, "tokenizer": tokenizer, } - + shutil.rmtree(test_dir) + @pytest.fixture def random_dataset(): dataset = RandomDataset() yield dataset + @pytest.fixture def multi_turn_dataset(): dataset = MultiTurnDataset() yield dataset - + + @pytest.fixture def early_stopping_dataset(): dataset = EarlyStoppingDataset() - yield dataset \ No newline at end of file + yield dataset diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index 57071e6..abb9b33 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR from khaosz.data.serialization import Checkpoint from khaosz.parallel.setup import get_rank, spawn_parallel_fn + def test_single_process(): model = torch.nn.Linear(10, 5) optimizer = AdamW(model.parameters(), lr=1e-3) @@ -14,34 +15,31 @@ def test_single_process(): for epoch in range(3): for iteration in range(10): - x = torch.randn(32, 10) y = torch.randn(32, 5) loss = model(x).mean() loss.backward() optimizer.step() optimizer.zero_grad() - + scheduler.step() - - checkpoint = Checkpoint( - state_dict=model.state_dict(), - epoch=3, - iteration=30 - ) - + + checkpoint = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=30) + with tempfile.TemporaryDirectory() as tmpdir: checkpoint.save(tmpdir) - + loaded_checkpoint = Checkpoint.load(tmpdir) - + assert loaded_checkpoint.epoch == 3 assert loaded_checkpoint.iteration == 30 + + def simple_training(): model = torch.nn.Linear(10, 5) optimizer = AdamW(model.parameters(), lr=1e-3) scheduler = CosineAnnealingLR(optimizer, T_max=10) - + for epoch in range(2): for iteration in range(5): x = torch.randn(16, 10) @@ -57,28 +55,23 @@ def simple_training(): epoch=2, iteration=10, ) - + rank = get_rank() - + if rank == 0: shared_dir = tempfile.mkdtemp() checkpoint.save(shared_dir) else: shared_dir = None - if dist.is_initialized(): dir_list = [shared_dir] dist.broadcast_object_list(dir_list, src=0) shared_dir = dir_list[0] - - + loaded = Checkpoint.load(shared_dir) assert loaded.epoch == 2 + def test_multi_process(): - spawn_parallel_fn( - simple_training, - world_size=2, - backend="gloo" - ) \ No newline at end of file + spawn_parallel_fn(simple_training, world_size=2, backend="gloo") diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index a9aac47..0420e23 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -5,30 +5,32 @@ from khaosz.data.serialization import save_h5 from khaosz.data.dataset import * - def test_dataset_loader_random_paths(base_test_env): """Test dataset loader with multiple random paths""" test_dir = base_test_env["test_dir"] - + # Create multiple mmap dataset directories with random data num_files = np.random.randint(2, 5) - + for i in range(num_files): seq_length = np.random.randint(200, 400) dummy_data = { - "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)], + "sequence": [ + torch.randint(0, 1000, (seq_length,), dtype=torch.int64) + for _ in range(10) + ], } save_h5(test_dir, f"data_{i}", dummy_data) - + # Test loading with multiple paths loaded_dataset = DatasetLoader.load( - train_type="seq", - load_path=test_dir, - window_size=64, + train_type="seq", + load_path=test_dir, + window_size=64, ) assert loaded_dataset is not None assert len(loaded_dataset) > 0 - + # Test that we can get items without errors for i in range(len(loaded_dataset)): item = loaded_dataset[i] @@ -41,30 +43,30 @@ def test_dataset_loader_random_paths(base_test_env): def test_dpo_strategy_with_random_data(base_test_env): """Test DPO strategy with randomized preference data""" test_dir = base_test_env["test_dir"] - + # Create DPO-style data with memory mapping format seq_length = np.random.randint(100, 200) - + dummy_data = { "chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "chosen_mask": [torch.ones(seq_length, dtype=torch.bool)], - "rejected_mask": [torch.ones(seq_length, dtype=torch.bool)] + "rejected_mask": [torch.ones(seq_length, dtype=torch.bool)], } - + save_h5(test_dir, "dpo_data", dummy_data) - + # Load DPO dataset dpo_dataset = DatasetLoader.load( - train_type="dpo", - load_path=test_dir, - window_size=64, + train_type="dpo", + load_path=test_dir, + window_size=64, ) - + assert dpo_dataset is not None - assert hasattr(dpo_dataset, 'fetcher') + assert hasattr(dpo_dataset, "fetcher") assert len(dpo_dataset) > 0 - + # Test that we can get DPO items without errors for i in range(min(3, len(dpo_dataset))): item = dpo_dataset[i] @@ -79,28 +81,28 @@ def test_dpo_strategy_with_random_data(base_test_env): def test_sft_dataset_with_random_data(base_test_env): """Test SFT dataset with random data""" test_dir = base_test_env["test_dir"] - + # Create SFT-style data with memory mapping format seq_length = np.random.randint(100, 200) - + dummy_data = { "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], - "loss_mask": [torch.ones(seq_length, dtype=torch.bool)] + "loss_mask": [torch.ones(seq_length, dtype=torch.bool)], } - + save_h5(test_dir, "sft_data", dummy_data) - + # Load SFT dataset sft_dataset = DatasetLoader.load( - train_type="sft", - load_path=test_dir, - window_size=64, + train_type="sft", + load_path=test_dir, + window_size=64, ) - + assert sft_dataset is not None - assert hasattr(sft_dataset, 'fetcher') + assert hasattr(sft_dataset, "fetcher") assert len(sft_dataset) > 0 - + # Test that we can get SFT items without errors for i in range(min(3, len(sft_dataset))): item = sft_dataset[i] @@ -114,33 +116,30 @@ def test_sft_dataset_with_random_data(base_test_env): def test_dataset_with_custom_stride(base_test_env): """Test dataset with custom stride parameter""" test_dir = base_test_env["test_dir"] - + # Create test data seq_length = 200 dummy_data = { "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], } - - save_h5(test_dir,"stride_test_data", dummy_data) - + + save_h5(test_dir, "stride_test_data", dummy_data) + # Test with custom stride custom_stride = 32 dataset = DatasetLoader.load( - train_type="seq", - load_path=test_dir, - window_size=64, - stride=custom_stride + train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride ) - + assert dataset is not None assert len(dataset) > 0 - + # With stride 32 and window 64 on 200 length data, we should get more samples # than with default stride (which equals window size) default_stride_dataset = DatasetLoader.load( - train_type="seq", - load_path=test_dir, - window_size=64, + train_type="seq", + load_path=test_dir, + window_size=64, ) - + assert len(dataset) > len(default_stride_dataset) diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 3876ebd..94f7a77 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -1,30 +1,32 @@ from khaosz.trainer import * from khaosz.data import * + def test_random_sampler_consistency(random_dataset): """Test RandomSampler produces consistent results with same seed""" dataset = random_dataset - + # Create two samplers with same seed sampler1 = ResumableDistributedSampler(dataset, seed=42) sampler2 = ResumableDistributedSampler(dataset, seed=42) - + indices1 = list(iter(sampler1)) indices2 = list(iter(sampler2)) - + assert indices1 == indices2 + def test_random_sampler_different_seeds(random_dataset): """Test RandomSampler produces different results with different seeds""" dataset = random_dataset - + # Create two samplers with different seeds sampler1 = ResumableDistributedSampler(dataset, seed=42) sampler2 = ResumableDistributedSampler(dataset, seed=123) - + indices1 = list(iter(sampler1)) indices2 = list(iter(sampler2)) - + # Very high probability they should be different assert indices1 != indices2 @@ -33,20 +35,20 @@ def test_sampler_across_epochs(random_dataset): """Test sampler behavior across multiple epochs""" dataset = random_dataset n = len(dataset) - + sampler = ResumableDistributedSampler(dataset, seed=42) - + # Get indices for first epoch epoch1_indices = list(iter(sampler)) assert len(epoch1_indices) == n - + # Get indices for second epoch epoch2_indices = list(iter(sampler)) assert len(epoch2_indices) == n - + # Check that epochs have different order (should be random) assert epoch1_indices != epoch2_indices - + # Check that all indices are present in each epoch assert set(epoch1_indices) == set(range(n)) - assert set(epoch2_indices) == set(range(n)) \ No newline at end of file + assert set(epoch2_indices) == set(range(n)) diff --git a/tests/module/test_module.py b/tests/module/test_module.py index dce5508..03c60b3 100644 --- a/tests/module/test_module.py +++ b/tests/module/test_module.py @@ -12,6 +12,7 @@ from khaosz.data import * from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore from tokenizers import pre_tokenizers + @pytest.fixture def test_env(request: pytest.FixtureRequest): func_name = request.function.__name__ @@ -19,7 +20,7 @@ def test_env(request: pytest.FixtureRequest): config_path = os.path.join(test_dir, "config.json") tokenizer_path = os.path.join(test_dir, "tokenizer.json") model_path = os.path.join(test_dir, "model.safetensors") - + config = { "vocab_size": 1000, "dim": 128, @@ -28,20 +29,20 @@ def test_env(request: pytest.FixtureRequest): "dim_ffn": 256, "max_len": 64, "n_layers": 2, - "norm_eps": 1e-5 + "norm_eps": 1e-5, } - with open(config_path, 'w') as f: + with open(config_path, "w") as f: json.dump(config, f) - + tokenizer = BpeTokenizer() sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet()) tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) tokenizer.save(tokenizer_path) - + transformer_config = ModelConfig().load(config_path) model = Transformer(transformer_config) st.save_file(model.state_dict(), model_path) - + yield { "test_dir": test_dir, "model": model, @@ -51,47 +52,55 @@ def test_env(request: pytest.FixtureRequest): shutil.rmtree(test_dir) + def test_model_parameter(test_env): save_dir = os.path.join(test_env["test_dir"], "save") - model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"]) + model_param = ModelParameter( + test_env["model"], test_env["tokenizer"], test_env["transformer_config"] + ) ModelParameter.save(model_param, save_dir) - + assert os.path.exists(os.path.join(save_dir, "model.safetensors")) assert os.path.exists(os.path.join(save_dir, "tokenizer.json")) assert os.path.exists(os.path.join(save_dir, "config.json")) + # transformer def test_transformer(test_env): model = test_env["model"] - input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, - (4, test_env["transformer_config"].max_len)) + input_ids = torch.randint( + 0, + test_env["transformer_config"].vocab_size, + (4, test_env["transformer_config"].max_len), + ) output_logits = model(input_ids)["logits"] - target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size) + target_shape = ( + 4, + test_env["transformer_config"].max_len, + test_env["transformer_config"].vocab_size, + ) assert output_logits.shape == target_shape - + + # generator def test_embedding_encoder_core(test_env): parameter = ModelParameter( - test_env["model"], - test_env["tokenizer"], - test_env["transformer_config"] + test_env["model"], test_env["tokenizer"], test_env["transformer_config"] ) encoder = EmbeddingEncoderCore(parameter) - + single_emb = encoder.encode("测试文本") assert isinstance(single_emb, torch.Tensor) assert single_emb.shape[-1] == test_env["transformer_config"].dim - batch_emb = encoder.encode(["测试1", "测试2"]) assert isinstance(batch_emb, list) assert len(batch_emb) == 2 - + + def test_generator_core(test_env): parameter = ModelParameter( - test_env["model"], - test_env["tokenizer"], - test_env["transformer_config"] + test_env["model"], test_env["tokenizer"], test_env["transformer_config"] ) generator = GeneratorCore(parameter) input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)) @@ -102,8 +111,8 @@ def test_generator_core(test_env): top_p=0.95, attn_mask=None, kv_caches=None, - start_pos=0 + start_pos=0, ) - + assert next_token_id.shape == (4, 1) assert cache_increase == 10 diff --git a/tests/module/test_tie_weight.py b/tests/module/test_tie_weight.py index 2542f0a..221f180 100644 --- a/tests/module/test_tie_weight.py +++ b/tests/module/test_tie_weight.py @@ -13,7 +13,7 @@ def transformer_test_env(): """创建Transformer测试专用环境""" test_dir = tempfile.mkdtemp(prefix="transformer_test_") config_path = os.path.join(test_dir, "config.json") - + config = { "vocab_size": 1000, "dim": 128, @@ -22,18 +22,14 @@ def transformer_test_env(): "dim_ffn": 256, "max_len": 64, "n_layers": 2, - "norm_eps": 1e-5 + "norm_eps": 1e-5, } - - with open(config_path, 'w') as f: + + with open(config_path, "w") as f: json.dump(config, f) - - yield { - "test_dir": test_dir, - "config_path": config_path, - "config": config - } - + + yield {"test_dir": test_dir, "config_path": config_path, "config": config} + if os.path.exists(test_dir): try: for file in os.listdir(test_dir): @@ -46,74 +42,75 @@ def transformer_test_env(): def test_tie_weight_init(transformer_test_env): config_path = transformer_test_env["config_path"] config_data = transformer_test_env["config"].copy() - + # case 1: tie weight config_data["tie_weight"] = True - - with open(config_path, 'w') as f: + + with open(config_path, "w") as f: json.dump(config_data, f) - + config = ModelConfig().load(config_path) model = Transformer(config) - + assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() - + original_weight = model.embed_tokens.weight.clone() model.embed_tokens.weight.data[0, 0] = 100.0 - + assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, original_weight) - + # case 2: not tie weight config_data["tie_weight"] = False - - with open(config_path, 'w') as f: + + with open(config_path, "w") as f: json.dump(config_data, f) - + config = ModelConfig().load(config_path) model = Transformer(config) - + assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() original_weight = model.embed_tokens.weight.clone() model.embed_tokens.weight.data[0, 0] = 100.0 - + assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, original_weight) + def test_model_save_load_with_tie_weight(transformer_test_env): test_dir = transformer_test_env["test_dir"] model_path = os.path.join(test_dir, "model.safetensors") - + config_data = transformer_test_env["config"].copy() - + # case 1: tie weight config_data["tie_weight"] = True config_path = os.path.join(test_dir, "config.json") - - with open(config_path, 'w') as f: + + with open(config_path, "w") as f: json.dump(config_data, f) - + config = ModelConfig().load(config_path) original_model = Transformer(config) - + st.save_file(original_model.state_dict(), model_path) loaded_config = ModelConfig().load(config_path) model = Transformer(loaded_config) model.load_state_dict(st.load_file(model_path)) - + assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert "lm_head.weight" not in model.state_dict() # case 2: not tie weight (form tie-weight state dict load) config_data["tie_weight"] = False - with open(config_path, 'w') as f: + with open(config_path, "w") as f: json.dump(config_data, f) - + loaded_config = ModelConfig().load(config_path) model = Transformer(loaded_config) model.load_state_dict(st.load_file(model_path)) @@ -121,4 +118,3 @@ def test_model_save_load_with_tie_weight(transformer_test_env): assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert "lm_head.weight" in model.state_dict() - \ No newline at end of file diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 714cc74..c5207b9 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -1,16 +1,14 @@ import torch import torch.distributed as dist -from khaosz.parallel import ( - get_rank, - only_on_rank, - spawn_parallel_fn -) +from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn + @only_on_rank(0) def _test_only_on_rank_helper(): return True + def only_on_rank(): result = _test_only_on_rank_helper() if get_rank() == 0: @@ -18,22 +16,17 @@ def only_on_rank(): else: assert result is None + def all_reduce(): x = torch.tensor([get_rank()], dtype=torch.int) dist.all_reduce(x, op=dist.ReduceOp.SUM) expected_sum = sum(range(dist.get_world_size())) assert x.item() == expected_sum + def test_spawn_only_on_rank(): - spawn_parallel_fn( - only_on_rank, - world_size=2, - backend="gloo" - ) + spawn_parallel_fn(only_on_rank, world_size=2, backend="gloo") + def test_spawn_all_reduce(): - spawn_parallel_fn( - all_reduce, - world_size=2, - backend="gloo" - ) \ No newline at end of file + spawn_parallel_fn(all_reduce, world_size=2, backend="gloo") diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 0ae63d8..08fef7b 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -3,57 +3,48 @@ import torch from khaosz.config import * from khaosz.trainer import * + def test_callback_integration(base_test_env, random_dataset): """Test that all callbacks are properly integrated""" - schedule_config = CosineScheduleConfig( - warmup_steps=10, - total_steps=20 - ) - + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) + optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - + train_config = TrainConfig( model=base_test_env["model"], - strategy='seq', + strategy="seq", dataset=random_dataset, optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, - checkpoint_dir=base_test_env["test_dir"], + ckpt_dir=base_test_env["test_dir"], n_epoch=1, batch_size=2, - checkpoint_interval=3, + ckpt_interval=3, accumulation_steps=1, max_grad_norm=1.0, random_seed=42, - device_type=base_test_env["device"] + device_type=base_test_env["device"], ) - - # Create custom callbacks to track calls callback_calls = [] - + class TrackingCallback(TrainCallback): def on_train_begin(self, context): - callback_calls.append('on_train_begin') - - def on_batch_end(self, context): - callback_calls.append('on_batch_end') - - def on_epoch_end(self, context): - callback_calls.append('on_epoch_end') - + callback_calls.append("on_train_begin") + + def on_batch_end(self, context): + callback_calls.append("on_batch_end") + + def on_epoch_end(self, context): + callback_calls.append("on_epoch_end") + + trainer = Trainer(train_config, callbacks=[TrackingCallback()]) - - trainer = Trainer( - train_config, - callbacks=[TrackingCallback()] - ) - trainer.train() - + # Verify callbacks were called - assert 'on_train_begin' in callback_calls - assert 'on_batch_end' in callback_calls - assert 'on_epoch_end' in callback_calls \ No newline at end of file + assert "on_train_begin" in callback_calls + assert "on_batch_end" in callback_calls + assert "on_epoch_end" in callback_calls diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index dc794f4..a9a244a 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -5,31 +5,32 @@ from khaosz.config import * from khaosz.trainer import * from khaosz.data.serialization import Checkpoint + def test_early_stopping_simulation(base_test_env, early_stopping_dataset): """Simulate early stopping behavior""" - + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - + optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - + train_config = TrainConfig( strategy="seq", optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, model=base_test_env["model"], dataset=early_stopping_dataset, - checkpoint_dir=base_test_env["test_dir"], + ckpt_dir=base_test_env["test_dir"], n_epoch=2, batch_size=2, - checkpoint_interval=1, + ckpt_interval=1, accumulation_steps=2, random_seed=np.random.randint(1e4), - device_type=base_test_env["device"] + device_type=base_test_env["device"], ) trainer = Trainer(train_config) - + # Should handle early stopping gracefully checkpoint = None try: @@ -37,11 +38,11 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): except Exception: # Handle any exceptions pass - + load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2") checkpoint = Checkpoint.load(load_dir) trainer.train(checkpoint) - + load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10") checkpoint = Checkpoint.load(load_dir) - assert checkpoint.iteration == 10 \ No newline at end of file + assert checkpoint.iteration == 10 diff --git a/tests/trainer/test_train_strategy.py b/tests/trainer/test_train_strategy.py index 7f96294..ac7ad30 100644 --- a/tests/trainer/test_train_strategy.py +++ b/tests/trainer/test_train_strategy.py @@ -9,39 +9,41 @@ from khaosz.data.dataset import * def test_schedule_factory_random_configs(): """Test scheduler factory with random configurations""" - + # Create a simple model and optimizer for testing model = torch.nn.Linear(10, 2) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) - + # Test multiple random configurations for _ in range(5): # Test 5 random configurations schedule_configs = [ CosineScheduleConfig( warmup_steps=np.random.randint(50, 200), total_steps=np.random.randint(1000, 5000), - min_rate=np.random.uniform(0.01, 0.1) + min_rate=np.random.uniform(0.01, 0.1), ), SGDRScheduleConfig( warmup_steps=np.random.randint(50, 200), cycle_length=np.random.randint(500, 2000), t_mult=np.random.randint(1, 3), - min_rate=np.random.uniform(0.01, 0.1) - ) + min_rate=np.random.uniform(0.01, 0.1), + ), ] - + for config in schedule_configs: # Validate configuration config.validate() - + # Create scheduler using factory scheduler = SchedulerFactory.load(optimizer, config) - + # Verify scheduler type if isinstance(config, CosineScheduleConfig): assert isinstance(scheduler, CosineScheduler) assert scheduler.warmup_steps == config.warmup_steps - assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps + assert ( + scheduler.lr_decay_steps == config.total_steps - config.warmup_steps + ) assert scheduler.min_rate == config.min_rate elif isinstance(config, SGDRScheduleConfig): assert isinstance(scheduler, SGDRScheduler) @@ -49,17 +51,17 @@ def test_schedule_factory_random_configs(): assert scheduler.cycle_length == config.cycle_length assert scheduler.t_mult == config.t_mult assert scheduler.min_rate == config.min_rate - + # Test scheduler state dict functionality state_dict = scheduler.state_dict() - assert 'warmup_steps' in state_dict - assert 'min_rate' in state_dict - + assert "warmup_steps" in state_dict + assert "min_rate" in state_dict + # Test scheduler step functionality initial_lr = scheduler.get_last_lr() scheduler.step() new_lr = scheduler.get_last_lr() - + # Learning rate should change after step, or if it's the first step, # the epoch counter should increment assert initial_lr != new_lr or scheduler.last_epoch > -1 @@ -67,10 +69,10 @@ def test_schedule_factory_random_configs(): def test_schedule_factory_edge_cases(): """Test scheduler factory with edge cases and boundary conditions""" - + model = torch.nn.Linear(10, 2) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) - + # Test edge cases for CosineScheduleConfig edge_cases = [ # Minimal warmup and steps @@ -80,12 +82,12 @@ def test_schedule_factory_edge_cases(): # Zero min_rate (edge case) CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0), ] - + for config in edge_cases: config.validate() scheduler = SchedulerFactory.load(optimizer, config) assert scheduler is not None - + # Test multiple steps for _ in range(10): scheduler.step() @@ -93,7 +95,7 @@ def test_schedule_factory_edge_cases(): def test_schedule_factory_invalid_configs(): """Test scheduler factory with invalid configurations""" - + # Test invalid configurations that should raise errors invalid_configs = [ # Negative warmup steps @@ -104,7 +106,7 @@ def test_schedule_factory_invalid_configs(): {"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1}, {"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1}, ] - + for kwargs in invalid_configs: with pytest.raises(ValueError): config = CosineScheduleConfig(**kwargs) @@ -113,24 +115,24 @@ def test_schedule_factory_invalid_configs(): def test_schedule_factory_state_persistence(): """Test scheduler state persistence (save/load)""" - + model = torch.nn.Linear(10, 2) optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) - + config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1) scheduler = SchedulerFactory.load(optimizer, config) - + # Take a few steps for _ in range(5): scheduler.step() - + # Save state state_dict = scheduler.state_dict() - + # Create new scheduler and load state new_scheduler = SchedulerFactory.load(optimizer, config) new_scheduler.load_state_dict(state_dict) - + # Verify states match assert scheduler.last_epoch == new_scheduler.last_epoch - assert scheduler.get_last_lr() == new_scheduler.get_last_lr() \ No newline at end of file + assert scheduler.get_last_lr() == new_scheduler.get_last_lr() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ce93bec..a0cca66 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -6,100 +6,94 @@ from khaosz.config import * from khaosz.trainer import * from khaosz.data.dataset import * + def test_different_batch_sizes(base_test_env, random_dataset): """Test training with different batch sizes""" batch_sizes = [1, 2, 4, 8] - + for batch_size in batch_sizes: - schedule_config = CosineScheduleConfig( - warmup_steps=10, - total_steps=20 - ) + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - + train_config = TrainConfig( strategy="seq", model=base_test_env["model"], dataset=random_dataset, optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, - checkpoint_dir=base_test_env["test_dir"], + ckpt_dir=base_test_env["test_dir"], n_epoch=1, batch_size=batch_size, - checkpoint_interval=5, + ckpt_interval=5, accumulation_steps=1, max_grad_norm=1.0, random_seed=np.random.randint(1000), - device_type=base_test_env["device"] + device_type=base_test_env["device"], ) - + assert train_config.batch_size == batch_size + def test_gradient_accumulation(base_test_env, random_dataset): """Test training with different gradient accumulation steps""" accumulation_steps_list = [1, 2, 4] - + for accumulation_steps in accumulation_steps_list: - schedule_config = CosineScheduleConfig( - warmup_steps=10, - total_steps=20 - ) + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - + train_config = TrainConfig( strategy="seq", model=base_test_env["model"], optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, dataset=random_dataset, - checkpoint_dir=base_test_env["test_dir"], + ckpt_dir=base_test_env["test_dir"], n_epoch=1, batch_size=2, - checkpoint_interval=10, + ckpt_interval=10, accumulation_steps=accumulation_steps, max_grad_norm=1.0, random_seed=42, - device_type=base_test_env["device"] + device_type=base_test_env["device"], ) - + trainer = Trainer(train_config) trainer.train() - + assert train_config.accumulation_steps == accumulation_steps + def test_memory_efficient_training(base_test_env, random_dataset): """Test training with memory-efficient configurations""" # Test with smaller batch sizes and gradient checkpointing small_batch_configs = [ {"batch_size": 1, "accumulation_steps": 8}, {"batch_size": 2, "accumulation_steps": 4}, - {"batch_size": 4, "accumulation_steps": 2} + {"batch_size": 4, "accumulation_steps": 2}, ] - + for config in small_batch_configs: - schedule_config = CosineScheduleConfig( - warmup_steps=10, - total_steps=20 - ) + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - + train_config = TrainConfig( strategy="seq", model=base_test_env["model"], dataset=random_dataset, optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, - checkpoint_dir=base_test_env["test_dir"], + ckpt_dir=base_test_env["test_dir"], n_epoch=1, batch_size=config["batch_size"], - checkpoint_interval=5, + ckpt_interval=5, accumulation_steps=config["accumulation_steps"], max_grad_norm=1.0, random_seed=42, - device_type=base_test_env["device"] + device_type=base_test_env["device"], ) - - assert train_config.accumulation_steps == config["accumulation_steps"] \ No newline at end of file + + assert train_config.accumulation_steps == config["accumulation_steps"] diff --git a/tools/benchmark.py b/tools/benchmark.py index 5f73667..ebfd7af 100644 --- a/tools/benchmark.py +++ b/tools/benchmark.py @@ -17,41 +17,47 @@ class GenerationBenchmark: self, config: ModelConfig, device: str = "cuda", - dtype: torch.dtype = torch.float16 + dtype: torch.dtype = torch.float16, ): self.config = config self.device = device self.dtype = dtype self.model = Transformer(config).to(device=device, dtype=dtype) self.model.eval() - + def _initialize_kv_cache(self, batch_size: int) -> list: """初始化KV缓存""" config = self.config - shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads) + shape = ( + batch_size, + config.max_len, + config.n_layers, + config.n_kv_heads, + config.dim // config.n_heads, + ) k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) return (k_cache, v_cache) - + def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): prompt_ids = torch.randint( low=0, high=self.config.vocab_size, size=(batch_size, prompt_length), device=self.device, - dtype=torch.long + dtype=torch.long, ) - + gen_ids = torch.randint( low=0, high=self.config.vocab_size, size=(batch_size, total_length - prompt_length), device=self.device, - dtype=torch.long + dtype=torch.long, ) - + return prompt_ids, gen_ids - + @torch.inference_mode() def run_prefill_benchmark( self, @@ -59,32 +65,38 @@ class GenerationBenchmark: prompt_length: int = 512, num_trials: int = 10, ) -> BenchmarkResult: - + for _ in range(3): - prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length) + prompt_ids, _ = self._prepare_inputs( + batch_size, prompt_length, prompt_length + ) _ = self.model(prompt_ids) - + torch.cuda.synchronize() - + total_time = 0.0 total_tokens = batch_size * prompt_length * num_trials - + for trial in range(num_trials): - prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length) + prompt_ids, _ = self._prepare_inputs( + batch_size, prompt_length, prompt_length + ) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - + start_event.record() _ = self.model(prompt_ids) end_event.record() torch.cuda.synchronize() - + trial_time = start_event.elapsed_time(end_event) / 1000 total_time += trial_time - - print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " - f"({prompt_length / trial_time:.1f} tokens/s)") - + + print( + f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " + f"({prompt_length / trial_time:.1f} tokens/s)" + ) + return BenchmarkResult( total_tokens=total_tokens, total_time=total_time, @@ -95,9 +107,9 @@ class GenerationBenchmark: "prompt_length": prompt_length, "dtype": self.dtype, "device": self.device, - } + }, ) - + @torch.inference_mode() def run_decoding_benchmark( self, @@ -106,39 +118,43 @@ class GenerationBenchmark: gen_length: int = 128, num_trials: int = 5, ) -> BenchmarkResult: - + total_time = 0.0 total_tokens = batch_size * gen_length * num_trials - + for trial in range(num_trials): - - prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length) + prompt_ids, gen_ids = self._prepare_inputs( + batch_size, prompt_length, prompt_length + gen_length + ) kv_cache = self._initialize_kv_cache(batch_size) _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0) - + torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - + start_event.record() - + current_pos = prompt_length for i in range(gen_length): - input_token = gen_ids[:, i:i+1] - _ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos) + input_token = gen_ids[:, i : i + 1] + _ = self.model( + input_token, persistent_key_values=kv_cache, start_pos=current_pos + ) current_pos += 1 - + end_event.record() torch.cuda.synchronize() - + trial_time = start_event.elapsed_time(end_event) / 1000 total_time += trial_time - - print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " - f"({gen_length / trial_time:.1f} tokens/s)") - + print( + f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " + f"({gen_length / trial_time:.1f} tokens/s)" + ) + return BenchmarkResult( total_tokens=total_tokens, total_time=total_time, @@ -150,24 +166,28 @@ class GenerationBenchmark: "gen_length": gen_length, "dtype": self.dtype, "device": self.device, - } + }, ) def print_benchmark_result(result: BenchmarkResult): """打印基准测试结果""" benchmark_type = result.metadata["benchmark_type"] - + print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}") print(f"Total Tokens Processed: {result.total_tokens:,}") print(f"Time Consumed: {result.total_time:.3f}s") print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s") - + if benchmark_type == "prefill": - print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}") + print( + f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}" + ) elif benchmark_type == "decoding": - print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}") - + print( + f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}" + ) + print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}") print("-" * 80) @@ -183,16 +203,19 @@ if __name__ == "__main__": n_layers=24, norm_eps=1e-5, ) - + benchmark = GenerationBenchmark(config) - + print("=" * 80) print("Running Transformer Generation Benchmark") print("=" * 80) - - prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5) + + prefill_result = benchmark.run_prefill_benchmark( + batch_size=4, prompt_length=512, num_trials=5 + ) print_benchmark_result(prefill_result) - - gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5) + + gen_result = benchmark.run_decoding_benchmark( + batch_size=4, prompt_length=512, gen_length=128, num_trials=5 + ) print_benchmark_result(gen_result) - \ No newline at end of file diff --git a/tools/generate.py b/tools/generate.py index d5c96b0..22f1dc6 100644 --- a/tools/generate.py +++ b/tools/generate.py @@ -21,10 +21,10 @@ def processor( with disable_random_init(): param = ModelParameter.load(model_dir) - param.to(device='cuda', dtype=torch.bfloat16) + param.to(device="cuda", dtype=torch.bfloat16) generator = BatchGenerator(param) - with open(input_json_file, "r", encoding='utf-8') as f: + with open(input_json_file, "r", encoding="utf-8") as f: input_data = [json.loads(line) for line in f] queries = [item[question_key] for item in input_data] @@ -41,26 +41,62 @@ def processor( responses = generator.generate(request) - with open(output_json_file, "w", encoding='utf-8') as f: + with open(output_json_file, "w", encoding="utf-8") as f: for query, response in zip(queries, responses): output_item = {question_key: query, response_key: response} - f.write(json.dumps(output_item, ensure_ascii=False) + '\n') + f.write(json.dumps(output_item, ensure_ascii=False) + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.") - - parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.") - parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.") - parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.") - parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.") - parser.add_argument("--response_key", type=str, default="response", help="Key for the response in the output JSON.") - parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.") - parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.") - parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.") - parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.") + + parser.add_argument( + "--model_dir", type=str, required=True, help="Path to the model directory." + ) + parser.add_argument( + "--input_json_file", + type=str, + required=True, + help="Path to the input JSONL file.", + ) + parser.add_argument( + "--output_json_file", + type=str, + required=True, + help="Path to the output JSONL file.", + ) + parser.add_argument( + "--question_key", + type=str, + default="question", + help="Key for the question in the input JSON.", + ) + parser.add_argument( + "--response_key", + type=str, + default="response", + help="Key for the response in the output JSON.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.60, + help="Temperature for generating responses.", + ) + parser.add_argument( + "--top_k", type=int, default=30, help="Top-k value for generating responses." + ) + parser.add_argument( + "--top_p", + type=float, + default=0.95, + help="Top-p value for generating responses.", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for generating responses." + ) args = parser.parse_args() - + with torch.inference_mode(): - processor(**vars(args)) \ No newline at end of file + processor(**vars(args)) diff --git a/tools/perplexity.py b/tools/perplexity.py index 2a9cc65..a91d370 100644 --- a/tools/perplexity.py +++ b/tools/perplexity.py @@ -11,89 +11,99 @@ from khaosz.inference.core import disable_random_init def compute_perplexity( - model: nn.Module, - input_ids: Tensor, - input_mask: Tensor, - ) -> Tensor: + model: nn.Module, + input_ids: Tensor, + input_mask: Tensor, +) -> Tensor: """ - Compute the perplexity of a batch of input sequences, - where PPL = exp(-(1/N) * sum(log P(w_i | w_ argparse.Namespace: parser = argparse.ArgumentParser(description="Train the Transformer model.") - - parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.") - parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.") - parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.") - - parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.") - parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.") - parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.") - parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.") - parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.") - parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.") - parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.") - parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.") - parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.") - parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") - parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.") - parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory") - parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.") - parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.") + + parser.add_argument( + "--train_type", + type=str, + required=True, + choices=["seq", "sft", "dpo"], + help="Train type.", + ) + parser.add_argument( + "--data_root_path", + type=str, + required=True, + help="Path to the root directory of the dataset.", + ) + parser.add_argument( + "--param_path", + type=str, + required=True, + help="Path to the model parameters or resume checkpoint.", + ) + + parser.add_argument( + "--n_epoch", type=int, default=1, help="Number of epochs to train." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for training." + ) + parser.add_argument( + "--accumulation_steps", + type=int, + default=1, + help="Number of iterations between each optimizer step.", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=1000, + help="Number of iters between warnings.", + ) + parser.add_argument( + "--max_lr", type=float, default=3e-4, help="Max learning rate for training." + ) + parser.add_argument( + "--max_grad_norm", + type=float, + default=1.0, + help="Max gradient norm for clipping.", + ) + parser.add_argument( + "--adamw_beta1", + type=float, + default=0.9, + help="Beta values for AdamW optimizer.", + ) + parser.add_argument( + "--adamw_beta2", + type=float, + default=0.95, + help="Beta values for AdamW optimizer.", + ) + parser.add_argument( + "--adamw_weight_decay", + type=float, + default=0.01, + help="Weight decay for AdamW optimizer.", + ) + parser.add_argument( + "--random_seed", type=int, default=3407, help="Random seed for reproducibility." + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="Number of workers for data loading." + ) + parser.add_argument( + "--no_pin_memory", + action="store_false", + dest="pin_memory", + help="Disable pin memory", + ) + parser.add_argument( + "--window_size", + type=int, + default=None, + help="the max length of the input sequence.", + ) + parser.add_argument( + "--stride", type=int, default=None, help="the step size of the input sequence." + ) parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") - parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter") - - parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.") - parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") - parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") - parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") - + parser.add_argument( + "--label_smoothing", + type=float, + default=0.1, + help="cross_entropy function label smoothing parameter", + ) + + parser.add_argument( + "--ckpt_interval", + type=int, + default=5000, + help="Number of iters between checkpoints.", + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default="checkpoint", + help="Directory to save checkpoints.", + ) + parser.add_argument( + "--start_epoch", type=int, default=0, help="Start epoch for training." + ) + parser.add_argument( + "--start_batch", type=int, default=0, help="Start batch for training." + ) + parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") - parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.") - + parser.add_argument( + "--device_type", type=str, default="cuda", help="Device type to use." + ) + args = parser.parse_args() return args + def ddp_wrap(model: nn.Module): local_rank = get_rank() model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16) @@ -56,16 +146,21 @@ def ddp_wrap(model: nn.Module): model, device_ids=[local_rank], output_device=local_rank, - find_unused_parameters=False + find_unused_parameters=False, ) return ddp_model + def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: return optim.AdamW(model.parameters(), **kwargs) -def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler: + +def create_scheduler( + optimizer: optim.Optimizer, **kwargs +) -> optim.lr_scheduler.LRScheduler: return SchedulerFactory.load(optimizer, **kwargs) + def prepare_checkpoint(model: nn.Module) -> dict: return model.module.state_dict() @@ -81,8 +176,8 @@ def train( start_batch: int, accumulation_steps: int, warmup_steps: int, - checkpoint_interval: int, - checkpoint_dir: str, + ckpt_interval: int, + ckpt_dir: str, dpo_beta: float, adamw_beta1: float, adamw_beta2: float, @@ -99,48 +194,50 @@ def train( ): assert train_type in ["seq", "sft", "dpo"] assert os.path.exists(param_path) - + parameter = ModelParameter.load(param_path) if window_size is None: window_size = parameter.config.max_len model = parameter.model - - strategy_kwargs = { - "dpo_beta": dpo_beta, - "label_smoothing": label_smoothing - } + + strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing} dataset = DatasetLoader.load( train_type=train_type, load_path=data_root_path, window_size=window_size, - stride=stride - ) - - schedule_config = CosineScheduleConfig( - warmup_steps=warmup_steps, - total_steps=len(dataset) * n_epoch // (batch_size * nprocs), + stride=stride, ) - optimizer_fn = partial(create_optimizer, - **{"lr": max_lr, "betas": (adamw_beta1, adamw_beta2), "weight_decay": adamw_weight_decay}) - scheduler_fn = partial(create_scheduler, - **{"schedule_config": schedule_config}) - + schedule_config = CosineScheduleConfig( + warmup_steps=warmup_steps, + total_steps=len(dataset) * n_epoch // (batch_size * nprocs), + ) + + optimizer_fn = partial( + create_optimizer, + **{ + "lr": max_lr, + "betas": (adamw_beta1, adamw_beta2), + "weight_decay": adamw_weight_decay, + }, + ) + scheduler_fn = partial(create_scheduler, **{"schedule_config": schedule_config}) + train_config = TrainConfig( model=model, strategy=train_type, dataset=dataset, optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, - checkpoint_dir=checkpoint_dir, + ckpt_dir=ckpt_dir, n_epoch=n_epoch, batch_size=batch_size, start_epoch=start_epoch, start_batch=start_batch, - checkpoint_interval=checkpoint_interval, + ckpt_interval=ckpt_interval, accumulation_steps=accumulation_steps, max_grad_norm=max_grad_norm, random_seed=random_seed, @@ -152,11 +249,11 @@ def train( device_type=device_type, extra_kwargs=strategy_kwargs, ) - + trainer = Trainer(train_config) trainer.train() - + if __name__ == "__main__": args = parse_args() - train(**vars(args)) \ No newline at end of file + train(**vars(args))