From c51b203fde1cc27b165b9e7a922b39187c2f0188 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 18 Oct 2025 13:56:59 +0800
Subject: [PATCH] refactor(khaosz): restructure the project layout
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 benchmark.py                                  |   2 +-
 khaosz/__init__.py                            |  24 +--
 khaosz/config/__init__.py                     |  12 ++
 khaosz/config/model_config.py                 |  36 +++++
 .../parameter.py => config/param_config.py}   |  19 +--
 khaosz/{trainer => config}/train_config.py    |   8 +-
 khaosz/core/__init__.py                       |  27 ----
 khaosz/data/__init__.py                       |  30 ++++
 khaosz/{trainer => data}/data_util.py         |  67 +++-----
 khaosz/{core => data}/tokenizer.py            |   0
 khaosz/inference/core.py                      |  97 ++++++++++++
 khaosz/{core => inference}/generator.py       |  95 +----------
 khaosz/{model.py => khaosz.py}                |   4 +-
 khaosz/model/__init__.py                      |  17 ++
 .../{core/transformer.py => model/module.py}  | 146 +-----------------
 khaosz/model/transformer.py                   | 119 ++++++++++++++
 khaosz/trainer/__init__.py                    |   7 +-
 khaosz/trainer/train_callback.py              |   2 +-
 khaosz/trainer/train_context.py               |  54 +++----
 khaosz/trainer/trainer.py                     |   6 +-
 tests/conftest.py                             |   8 +-
 tests/test_callbacks.py                       |   3 +-
 tests/test_dataset_loader.py                  |   4 +-
 tests/test_early_stopping.py                  |   5 +-
 tests/test_module.py                          |   7 +-
 tests/test_sampler.py                         |  38 +----
 tests/test_train_config.py                    |   5 +-
 tests/test_train_strategy.py                  |   4 +-
 28 files changed, 423 insertions(+), 423 deletions(-)
 create mode 100644 khaosz/config/__init__.py
 create mode 100644 khaosz/config/model_config.py
 rename khaosz/{core/parameter.py => config/param_config.py} (95%)
 rename khaosz/{trainer => config}/train_config.py (91%)
 delete mode 100644 khaosz/core/__init__.py
 create mode 100644 khaosz/data/__init__.py
 rename khaosz/{trainer => data}/data_util.py (87%)
 rename khaosz/{core => data}/tokenizer.py (100%)
 create mode 100644 khaosz/inference/core.py
 rename khaosz/{core => inference}/generator.py (80%)
 rename khaosz/{model.py => khaosz.py} (95%)
 create mode 100644 khaosz/model/__init__.py
 rename khaosz/{core/transformer.py => model/module.py} (56%)
 create mode 100644 khaosz/model/transformer.py

diff --git a/benchmark.py b/benchmark.py
index 9ca49fe..f2db16e 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Any
 from dataclasses import dataclass
-from khaosz.core.transformer import TransformerConfig, Transformer
+from khaosz.model.transformer import TransformerConfig, Transformer
 
 
 @dataclass
diff --git a/khaosz/__init__.py b/khaosz/__init__.py
index 42e8acc..8556926 100644
--- a/khaosz/__init__.py
+++ b/khaosz/__init__.py
@@ -1,16 +1,23 @@
 __version__ = "1.3.0"
 __author__ = "ViperEkura"
 
-from khaosz.model import Khaosz
-from khaosz.core.transformer import Transformer, TransformerConfig
+from khaosz.khaosz import Khaosz
+from khaosz.config import (
+    TransformerConfig,
+    ParameterLoader,
+    TrainConfig,
+)
+from khaosz.model.transformer import Transformer
 from khaosz.utils.retriever import Retriever
 from khaosz.utils.splitter import (
     SemanticTextSplitter,
     PriorityTextSplitter
 )
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.parameter import ParameterLoader
-from khaosz.core.generator import (
+from khaosz.data import (
+    DatasetLoader,
+    BpeTokenizer
+)
+from khaosz.inference.generator import (
     TextGenerator,
     ChatGenerator,
     StreamGenerator,
@@ -18,10 +25,9 @@ from khaosz.core.generator import (
     RetrievalGenerator,
     EmbeddingEncoder
 )
+
 from khaosz.trainer import (
     Trainer,
-    DatasetLoader,
-    TrainConfig,
     StrategyFactory,
     SchedulerFactory
 )
@@ -44,7 +50,7 @@ __all__ = [
 
     # trainer
     "Trainer",
-    "DatasetLoader",
+    "DatasetLoader",  # kept in __all__, but now sourced from khaosz.data
     "TrainConfig",
     "StrategyFactory",
     "SchedulerFactory",
@@ -53,4 +59,4 @@ __all__ = [
     "Retriever",
     "SemanticTextSplitter",
     "PriorityTextSplitter",
-]
+]
\ No newline at end of file
diff --git a/khaosz/config/__init__.py b/khaosz/config/__init__.py
new file mode 100644
index 0000000..392bdc3
--- /dev/null
+++ b/khaosz/config/__init__.py
@@ -0,0 +1,12 @@
+from khaosz.config.model_config import TransformerConfig
+from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
+from khaosz.config.train_config import TrainConfig
+
+__all__ = [
+    "BaseModelIO",
+    "ModelParameter",
+    "Checkpoint",
+    "ParameterLoader",
+    "TransformerConfig",
+    "TrainConfig"
+]
\ No newline at end of file
diff --git a/khaosz/config/model_config.py b/khaosz/config/model_config.py
new file mode 100644
index 0000000..b62884a
--- /dev/null
+++ b/khaosz/config/model_config.py
@@ -0,0 +1,36 @@
+import json
+
+from dataclasses import asdict, dataclass
+from typing import Optional, Self
+
+@dataclass
+class TransformerConfig:
+    # basic config
+    vocab_size: Optional[int] = None
+    n_dim: Optional[int] = None
+    n_head: Optional[int] = None
+    n_layer: Optional[int] = None
+    m_len: Optional[int] = None
+    norm_eps: Optional[float] = None
+    d_ffn: Optional[int] = None
+
+    # GQA
+    n_kvhead: Optional[int] = None
+
+
+    def load(self, config_path: str) -> Self:
+        with open(config_path, 'r') as f:
+            config: dict = json.load(f)
+            for key, value in config.items():
+                if hasattr(self, key):
+                    setattr(self, key, value)
+
+        return self
+
+    def save(self, config_path: str) -> None:
+        config_dict = asdict(self)
+        config_dict = {k: v for k, v in config_dict.items() if v is not None}
+        with open(config_path, 'w') as f:
+            json.dump(config_dict, f, indent=4)
+
+
diff --git a/khaosz/core/parameter.py b/khaosz/config/param_config.py
similarity index 95%
rename from khaosz/core/parameter.py
rename to khaosz/config/param_config.py
index 2daa894..b7b3f05 100644
--- a/khaosz/core/parameter.py
+++ b/khaosz/config/param_config.py
@@ -8,8 +8,9 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path
 
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.transformer import TransformerConfig, Transformer
+from khaosz.data.tokenizer import BpeTokenizer
+from khaosz.config.model_config import TransformerConfig
+from khaosz.model.transformer import Transformer
 
 
 class BaseModelIO:
@@ -99,18 +100,18 @@ class Checkpoint(BaseModelIO):
         metadata={"help": "Transformer model."}
     )
     tokenizer: BpeTokenizer = field(
-        default_factory=BpeTokenizer,
+        default=None,
         metadata={"help": "Tokenizer for the model."}
     )
     config: TransformerConfig = field(
-        default_factory=TransformerConfig,
+        default=None,
         metadata={"help": "Transformer model configuration."}
     )
     optimizer_state: Dict[str, Any] = field(
         default=None,
         metadata={"help": "Optimizer state."}
     )
-    sampler_state: Dict[str, Any] = field(
+    scheduler_state: Dict[str, Any] = field(
         default=None,
-        metadata={"help": "Sampler state."}
+        metadata={"help": "Scheduler state."}
     )
@@ -142,10 +143,10 @@ class Checkpoint(BaseModelIO):
         # Save optimizer state
         with open(str(paths["optimizer_state"]), "wb") as f:
             pkl.dump(self.optimizer_state, f)
-
+        
         # Save sampler state
         with open(str(paths["sampler_state"]), "wb") as f:
-            pkl.dump(self.sampler_state, f)
+            pkl.dump(self.scheduler_state, f)
 
     def load_training_state(self, load_dir: Union[str, Path]) -> Self:
         paths = self._get_training_paths(load_dir)
@@ -163,7 +164,7 @@ class Checkpoint(BaseModelIO):
         # Load sampler state
         if paths["sampler_state"].exists():
             with open(str(paths["sampler_state"]), "rb") as f:
-                self.sampler_state = pkl.load(f)
+                self.scheduler_state = pkl.load(f)
 
         return self
 
@@ -173,7 +174,7 @@ class Checkpoint(BaseModelIO):
             return
 
         current_iter = len(self.loss_list)
-
+        
         plt.figure(figsize=(10, 6))
         plt.plot(self.loss_list)
         plt.title(f"Training Loss - Iteration {current_iter}")
diff --git a/khaosz/trainer/train_config.py b/khaosz/config/train_config.py
similarity index 91%
rename from khaosz/trainer/train_config.py
rename to khaosz/config/train_config.py
index e93e358..6f8162f 100644
--- a/khaosz/trainer/train_config.py
+++ b/khaosz/config/train_config.py
@@ -1,14 +1,16 @@
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
 from torch.utils.data import Dataset
 from torch.optim import Optimizer
-from khaosz.trainer.strategy import BaseStrategy
+
+if TYPE_CHECKING:
+    from khaosz.trainer.strategy import BaseStrategy
 
 
 @dataclass
 class TrainConfig:
-    strategy: BaseStrategy = field(
+    strategy: "BaseStrategy" = field(
         default=None,
         metadata={"help": "Training strategy."}
     )
diff --git a/khaosz/core/__init__.py b/khaosz/core/__init__.py
deleted file mode 100644
index 0b31b94..0000000
--- a/khaosz/core/__init__.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.transformer import Transformer, TransformerConfig
-from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
-from khaosz.core.generator import (
-    TextGenerator,
-    ChatGenerator,
-    StreamGenerator,
-    BatchGenerator,
-    RetrievalGenerator,
-    EmbeddingEncoder
-)
-
-
-__all__ = [
-    "Transformer",
-    "TransformerConfig",
-    "BpeTokenizer",
-    "ParameterLoader",
-    "ModelParameter",
-    "Checkpoint",
-    "TextGenerator",
-    "ChatGenerator",
-    "StreamGenerator",
-    "BatchGenerator",
-    "RetrievalGenerator",
-    "EmbeddingEncoder"
-]
\ No newline at end of file
diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py
new file mode 100644
index 0000000..10b3f44
--- /dev/null
+++ b/khaosz/data/__init__.py
@@ -0,0 +1,30 @@
+from khaosz.data.data_util import (
+    BaseDataset,
+    SeqDataset,
+    DpoDataset,
+    SftDataset,
+    PpoDataset,
+    MutiSegmentFetcher,
+    ResumeableRandomSampler,
+    DatasetLoader,
+    load_pkl_files,
+    build_attention_mask,
+    build_loss_mask
+)
+
+from khaosz.data.tokenizer import BpeTokenizer
+
+__all__ = [
+    "BaseDataset",
+    "SeqDataset",
+    "DpoDataset",
+    "SftDataset",
+    "PpoDataset",
+    "MutiSegmentFetcher",
+    "ResumeableRandomSampler",
+    "DatasetLoader",
+    "load_pkl_files",
+    "build_attention_mask",
+    "build_loss_mask",
+    "BpeTokenizer"
+]
\ No newline at end of file
diff --git a/khaosz/trainer/data_util.py b/khaosz/data/data_util.py
similarity index 87%
rename from khaosz/trainer/data_util.py
rename to khaosz/data/data_util.py
index 7b6570e..19d811c 100644
--- a/khaosz/trainer/data_util.py
+++ b/khaosz/data/data_util.py
@@ -263,58 +263,39 @@ class DatasetLoader:
         dataset.load(load_path)
         return dataset
 
-
-class RandomSampler(Sampler[int]):
-    def __init__(self, data_source, generator=None, seed=42):
-        self.data_source = data_source
-        self.seed = seed
-        self.epoch = 0
-        self.current_iter = 0
-        self._indices = None
+
+class ResumeableRandomSampler(Sampler[int]):
+    def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
+        self.num_samples = len(data_source)
+        self.epoch = start_epoch
+        self.iter = start_iter
 
-        if generator is None:
-            self.generator = torch.Generator()
-            self.generator.manual_seed(seed)
-        else:
-            self.generator = generator
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+
+        self.generator = generator
+        self._indices = None
 
-    def _generate_indices(self):
-        n = len(self.data_source)
-        self._indices = torch.randperm(n, generator=self.generator).tolist()
+    def _get_indices(self):
+        for _ in range(self.epoch):
+            _ = torch.randperm(self.num_samples, generator=self.generator)
+
+        current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        self._indices = current_epoch_indices[self.iter % self.num_samples:]
 
     def __iter__(self):
-        n = len(self.data_source)
-        if self._indices is None:
-            self._generate_indices()
+        self._get_indices()
 
-        start = self.current_iter % n
-        for i in range(start, n):
-            self.current_iter += 1
-            yield self._indices[i]
+        for i in self._indices:
+            self.iter += 1
+            yield i
 
         self.epoch += 1
         self._indices = None
 
     def __len__(self):
-        return len(self.data_source)
-
-    def state_dict(self):
-        return {
-            'epoch': self.epoch,
-            'current_iter': self.current_iter,
-            'seed': self.seed,
-            'generator_state': self.generator.get_state() if self.generator else None,
-            'indices': self._indices
-        }
-
-    def load_state_dict(self, state_dict):
-        self.epoch = state_dict['epoch']
-        self.current_iter = state_dict['current_iter']
-        self.seed = state_dict['seed']
-
-        if self.generator and state_dict['generator_state'] is not None:
-            self.generator.set_state(state_dict['generator_state'])
-
-        self._indices = state_dict['indices']
\ No newline at end of file
+        if self._indices is None:
+            self._get_indices()
+        return len(self._indices)
\ No newline at end of file
diff --git a/khaosz/core/tokenizer.py b/khaosz/data/tokenizer.py
similarity index 100%
rename from khaosz/core/tokenizer.py
rename to khaosz/data/tokenizer.py
diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py
new file mode 100644
index 0000000..90732e6
--- /dev/null
+++ b/khaosz/inference/core.py
@@ -0,0 +1,97 @@
+import torch
+
+from torch import Tensor
+from typing import List, Tuple, Union, Optional, Generator, Self
+from khaosz.config.param_config import ModelParameter
+
+
+class GeneratorCore:
+    def __init__(self, parameter: ModelParameter):
+        self.model = parameter.model
+        self.tokenizer = parameter.tokenizer
+        self.config = parameter.config
+
+    def compute_logits(
+        self,
+        input_ids: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
+        start_pos: int = 0
+    ) -> Tuple[Tensor, int]:
+        with torch.inference_mode():
+            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
+            logits = outputs["logits"][:, -1, :]
+            cache_increase = input_ids.size(-1)
+
+        return logits, cache_increase
+
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
+
+
+class EmbeddingEncoderCore:
+    def __init__(self, parameter: ModelParameter):
+        self.model = parameter.model
+        self.tokenizer = parameter.tokenizer
+        self.config = parameter.config
+
+    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
+        with_batch = isinstance(sentence, list)
+        ids = self.tokenizer.encode(sentence)
+        batch_ids = ids if with_batch else [ids]
+        max_model_len = self.config.m_len
+
+        all_fragments = []
+        fragment_origin_idx = []
+
+        for i, seq in enumerate(batch_ids):
+            if len(seq) > max_model_len:
+                fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
+                all_fragments.extend(fragments)
+                fragment_origin_idx.extend([i] * len(fragments))
+            else:
+                all_fragments.append(seq)
+                fragment_origin_idx.append(i)
+
+        # if empty fragments
+        if not all_fragments or not ids:
+            return [] if with_batch else torch.tensor([])
+
+        device = next(self.model.parameters()).device
+        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
+
+        padded_ids = []
+        masks = []
+        for seq in all_fragments:
+            pad_len = max_len - len(seq)
+            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
+            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
+            padded_ids.append(padded_seq)
+            masks.append(mask)
+
+        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
+        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
+
+        with torch.inference_mode():
+            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
+            # [num_fragments, seq_len, hidden_size]
+            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
+
+        sentence_embs: List[Tensor] = []
+        for i in range(len(batch_ids)):
+            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
+            if indices is not None:
+                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)  # [frags, hidden_size]
+                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
+                emb = torch.sum(sum_frags / length, dim=0)  # [hidden_size]
+                sentence_embs.append(emb.flatten())
+
+        if with_batch:
+            return [emb.flatten() for emb in sentence_embs]
+        else:
+            return sentence_embs[0].flatten()
+
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
\ No newline at end of file
diff --git a/khaosz/core/generator.py b/khaosz/inference/generator.py
similarity index 80%
rename from khaosz/core/generator.py
rename to khaosz/inference/generator.py
index 3aba293..719fbd2 100644
--- a/khaosz/core/generator.py
+++ b/khaosz/inference/generator.py
@@ -1,7 +1,8 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Union, Optional, Generator, Self
-from khaosz.core.parameter import ModelParameter
+from typing import List, Tuple, Union, Optional, Generator
+from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
+from khaosz.config.param_config import ModelParameter
 
 
 def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
@@ -168,96 +169,6 @@ class KVCacheManager:
         return self._seq_mask
 
 
-class GeneratorCore:
-    def __init__(self, parameter: ModelParameter):
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
-
-    def compute_logits(
-        self,
-        input_ids: Tensor,
-        attn_mask: Optional[Tensor] = None,
-        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
-        start_pos: int = 0
-    ) -> Tuple[Tensor, int]:
-        with torch.inference_mode():
-            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
-            logits = outputs["logits"][:, -1, :]
-            cache_increase = input_ids.size(-1)
-
-        return logits, cache_increase
-
-    def to(self, *args, **kargs) -> Self:
-        self.model.to(*args, **kargs)
-        return self
-
-
-class EmbeddingEncoderCore:
-    def __init__(self, parameter: ModelParameter):
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
-
-    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
-        with_batch = isinstance(sentence, list)
-        ids = self.tokenizer.encode(sentence)
-        batch_ids = ids if with_batch else [ids]
-        max_model_len = self.config.m_len
-
-        all_fragments = []
-        fragment_origin_idx = []
-
-        for i, seq in enumerate(batch_ids):
-            if len(seq) > max_model_len:
-                fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
-                all_fragments.extend(fragments)
-                fragment_origin_idx.extend([i] * len(fragments))
-            else:
-                all_fragments.append(seq)
-                fragment_origin_idx.append(i)
-
-        #if empty fragments
-        if not all_fragments or not ids:
-            return [] if with_batch else torch.tensor([])
-
-        device = next(self.model.parameters()).device
-        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
-
-        padded_ids = []
-        masks = []
-        for seq in all_fragments:
-            pad_len = max_len - len(seq)
-            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
-            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
-            padded_ids.append(padded_seq)
-            masks.append(mask)
-
-        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
-        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
-
-        with torch.inference_mode():
-            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
-            # [num_fragments, seq_len, hidden_size]
-            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
-
-        sentence_embs: List[Tensor] = []
-        for i in range(len(batch_ids)):
-            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
-            if indices is not None:
-                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)  # [frags, hidden_size]
-                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
-                emb = torch.sum(sum_frags / length, dim=0)  # [frags, hidden_size]
-                sentence_embs.append(emb.flatten())
-
-        if with_batch:
-            return [emb.flatten() for emb in sentence_embs]
-        else:
-            return sentence_embs[0].flatten()
-
-    def to(self, *args, **kargs) -> Self:
-        self.model.to(*args, **kargs)
-        return self
 
 
 class TextGenerator(GeneratorCore):
diff --git a/khaosz/model.py b/khaosz/khaosz.py
similarity index 95%
rename from khaosz/model.py
rename to khaosz/khaosz.py
index 3723f29..96fc988 100644
--- a/khaosz/model.py
+++ b/khaosz/khaosz.py
@@ -1,7 +1,7 @@
 from torch import Tensor
 from typing import List, Tuple, Generator, Union
 
-from khaosz.core.generator import (
+from khaosz.inference.generator import (
     TextGenerator,
     ChatGenerator,
     StreamGenerator,
@@ -9,7 +9,7 @@ from khaosz.core.generator import (
     RetrievalGenerator,
     EmbeddingEncoder
 )
-from khaosz.core.parameter import ParameterLoader
+from khaosz.config.param_config import ParameterLoader
 
 
 class Khaosz:
diff --git a/khaosz/model/__init__.py b/khaosz/model/__init__.py
new file mode 100644
index 0000000..c70b05c
--- /dev/null
+++ b/khaosz/model/__init__.py
@@ -0,0 +1,17 @@
+from khaosz.model.module import (
+    Linear,
+    RMSNorm,
+    MLP,
+    GQA,
+    DecoderBlock,
+)
+from khaosz.model.transformer import Transformer
+
+__all__ = [
+    "Linear",
+    "RMSNorm",
+    "MLP",
+    "GQA",
+    "DecoderBlock",
+    "Transformer"
+]
\ No newline at end of file
diff --git a/khaosz/core/transformer.py b/khaosz/model/module.py
similarity index 56%
rename from khaosz/core/transformer.py
rename to khaosz/model/module.py
index a324c63..bba68ce 100644
--- a/khaosz/core/transformer.py
+++ b/khaosz/model/module.py
@@ -1,12 +1,11 @@
-import json
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from torch import Tensor
 from torch.nn import init
-from dataclasses import asdict, dataclass
-from typing import List, Optional, Self, Tuple
+from typing import Optional, Tuple
 
 
 def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@@ -71,89 +70,6 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     return x_out.to(dtype)
 
 
-def process_attention_mask(
-    seq_mask: Tensor,
-    start_pos: int = 0,
-    seq_len: int = 0,
-    is_causal: bool = False,
-    device: torch.device = "cuda",
-    dtype: torch.dtype = torch.float32
-    ) -> Tensor:
-    """
-    Create attention mask for GQA
-    Args:
-        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
-        start_pos (int): The starting position of the sequence.
-        seq_len (int): The length of the sequence.
-        is_causal (bool): Whether the attention is causal or not.
-        device (torch.device): The device to use.
-    Returns:
-        Tensor: The attention mask tensor.
-    """
-
-    if seq_mask is None:
-        if start_pos != 0:
-            # for single prompt chat
-            seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
-        else:
-            return None
-
-    if seq_mask.dim() > 2:
-        # shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
-        # if ndim > 2, it's 4D tensor
-        return seq_mask
-
-    batch_size = seq_mask.size(0)
-    seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
-    # (bsz, start_pos + seq_len)
-    expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
-    # (bsz, seq_len, start_pos + seq_len)
-
-    if is_causal:
-        causal_mask = torch.tril(
-            torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
-            diagonal=start_pos
-        )
-        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
-        expanded_mask = expanded_mask & causal_mask
-
-    attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
-    attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
-    # (bsz, 1, seq_len, seq_len + start_pos)
-
-    return attention_mask
-
-
-@dataclass
-class TransformerConfig:
-    # basic config
-    vocab_size: Optional[int] = None
-    n_dim: Optional[int] = None
-    n_head: Optional[int] = None
-    n_layer: Optional[int] = None
-    m_len: Optional[int] = None
-    norm_eps: Optional[float] = None
-    d_ffn: Optional[int] = None
-
-    # GQA
-    n_kvhead: Optional[int] = None
-
-
-    def load(self, config_path: str) -> Self:
-        with open(config_path, 'r') as f:
-            config: dict = json.load(f)
-            for key, value in config.items():
-                if hasattr(self, key):
-                    setattr(self, key, value)
-
-        return self
-
-    def save(self, config_path: str) -> None:
-        config_dict = asdict(self)
-        config_dict = {k: v for k, v in config_dict.items() if v is not None}
-        with open(config_path, 'w') as f:
-            json.dump(config_dict, f, indent=4)
-
 
 class Linear(nn.Module):
     def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
@@ -287,60 +203,4 @@ class DecoderBlock(nn.Module):
         # feed forward
         x = self.ffn(self.norm_ffn(x)) + x
 
-        return x
-
-
-class Transformer(nn.Module):
-    def __init__(self, config: TransformerConfig):
-        super().__init__()
-        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
-        self.layers = nn.ModuleList([
-            DecoderBlock(
-                config.n_dim,
-                config.n_head,
-                config.d_ffn,
-                config.n_kvhead,
-                config.norm_eps
-            )
-            for _ in range(config.n_layer)
-        ])
-        self.norm = RMSNorm(config.n_dim, config.norm_eps)
-        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
-        init.normal_(self.embedding, mean=0, std=0.02)
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        input_mask: Optional[Tensor]=None,
-        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
-        start_pos: int = 0
-    ) -> Tensor:
-        assert input_ids.ndim == 2
-        seq_len = input_ids.size(-1)
-        x = F.embedding(input_ids, self.embedding)
-
-        self.freq_cis = self.freq_cis.to(x.device)
-        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
-        has_kvcache = persistent_key_values is not None
-
-        attn_mask = process_attention_mask(
-            input_mask,
-            start_pos=start_pos,
-            seq_len=seq_len,
-            is_causal=has_kvcache,
-            device=x.device,
-            dtype=x.dtype
-        )
-
-        for i, layer in enumerate(self.layers):
-            kv_cache = persistent_key_values[i] if persistent_key_values else None
-            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
-
-        hidden_states = self.norm(x)
-        logits = F.linear(hidden_states, self.embedding)
-
-        return {
-            "logits": logits,
-            "hidden_states": hidden_states
-        }
-        
\ No newline at end of file
+        return x
\ No newline at end of file
diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
new file mode 100644
index 0000000..6290fe6
--- /dev/null
+++ b/khaosz/model/transformer.py
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+from torch.nn import init
+from typing import List, Optional, Tuple
+
+from khaosz.config.model_config import TransformerConfig
+from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
+
+
+def process_attention_mask(
+    seq_mask: Tensor,
+    start_pos: int = 0,
+    seq_len: int = 0,
+    is_causal: bool = False,
+    device: torch.device = "cuda",
+    dtype: torch.dtype = torch.float32
+    ) -> Tensor:
+    """
+    Create attention mask for GQA
+    Args:
+        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
+        start_pos (int): The starting position of the sequence.
+        seq_len (int): The length of the sequence.
+        is_causal (bool): Whether the attention is causal or not.
+        device (torch.device): The device to use.
+    Returns:
+        Tensor: The attention mask tensor.
+ """ + + if seq_mask is None: + if start_pos != 0: + # for single prompt chat + seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) + else: + return None + + if seq_mask.dim() > 2: + # shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos) + # if ndim > 2, it's 4D tensor + return seq_mask + + batch_size = seq_mask.size(0) + seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool) + # (bsz, start_pos + seq_len) + expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len) + # (bsz, seq_len, start_pos + seq_len) + + if is_causal: + causal_mask = torch.tril( + torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device), + diagonal=start_pos + ) + causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len) + expanded_mask = expanded_mask & causal_mask + + attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) + attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1) + # (bsz, 1, seq_len, seq_len + start_pos) + + return attention_mask + + +class Transformer(nn.Module): + def __init__(self, config: TransformerConfig): + super().__init__() + self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim)) + self.layers = nn.ModuleList([ + DecoderBlock( + config.n_dim, + config.n_head, + config.d_ffn, + config.n_kvhead, + config.norm_eps + ) + for _ in range(config.n_layer) + ]) + self.norm = RMSNorm(config.n_dim, config.norm_eps) + self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len) + init.normal_(self.embedding, mean=0, std=0.02) + + def forward( + self, + input_ids: Tensor, + input_mask: Optional[Tensor]=None, + persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None, + start_pos: int = 0 + ) -> Tensor: + assert input_ids.ndim == 2 + seq_len = input_ids.size(-1) + x = F.embedding(input_ids, self.embedding) + + self.freq_cis = self.freq_cis.to(x.device) + freqs_cis = self.freq_cis[start_pos:start_pos+seq_len] + has_kvcache = persistent_key_values is not None + + attn_mask = process_attention_mask( + input_mask, + start_pos=start_pos, + seq_len=seq_len, + is_causal=has_kvcache, + device=x.device, + dtype=x.dtype + ) + + for i, layer in enumerate(self.layers): + kv_cache = persistent_key_values[i] if persistent_key_values else None + x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos) + + hidden_states = self.norm(x) + logits = F.linear(hidden_states, self.embedding) + + return { + "logits": logits, + "hidden_states": hidden_states + } + \ No newline at end of file diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index 5c15fbd..e264dfe 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -1,6 +1,4 @@ -from khaosz.trainer.data_util import DatasetLoader from khaosz.trainer.trainer import Trainer -from khaosz.trainer.train_config import TrainConfig from khaosz.trainer.strategy import ( CosineScheduleConfig, SgdrScheduleConfig, @@ -17,19 +15,16 @@ from khaosz.trainer.train_callback import ( ) __all__ = [ - "DatasetLoader", "Trainer", - "TrainConfig", + "StrategyFactory", "CosineScheduleConfig", "SgdrScheduleConfig", - "StrategyFactory", "SchedulerFactory", # callback "TrainCallback", "ProgressBarCallback", "CheckpointCallback", - "TrainCallback", "SchedulerCallback", "StepMonitorCallback" ] \ No newline at end of file diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index c83921f..3ca86bb 
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -106,7 +106,7 @@ class CheckpointCallback(TrainCallback):
     def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
         save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
 
-        context.checkpoint.sampler_state = context.sampler.state_dict()
+        # context.checkpoint.scheduler_state = context.sampler.state_dict()
         context.checkpoint.optimizer_state = context.optimizer.state_dict()
         context.checkpoint.save(save_path)
         self.last_ckpt_iter = context.current_iter
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index 50145e4..d9b8a9c 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field, fields
 from typing import Optional, Self, TYPE_CHECKING
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
-from khaosz.core.parameter import Checkpoint
-from khaosz.trainer.data_util import RandomSampler
+from khaosz.config.param_config import Checkpoint
+from khaosz.data.data_util import ResumeableRandomSampler
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -13,11 +14,11 @@ if TYPE_CHECKING:
 class TrainContext:
     dataloader: DataLoader = field(default=None)
     optimizer: Optimizer = field(default=None)
-    sampler: RandomSampler = field(default=None)
+    scheduler: LRScheduler = field(default=None)
+    checkpoint: Checkpoint = field(default=None)
     epoch: int = field(default=0)
     current_iter: int = field(default=0)
     loss: float = field(default=0.0)
-    checkpoint: Checkpoint = field(default=None)
 
     def asdict(self) -> dict:
         return {field.name: getattr(self, field.name)
@@ -27,15 +28,7 @@ class TrainContextBuilder:
     def __init__(self, trainer: 'Trainer'):
         self.trainer = trainer
-        self._context = TrainContext(
-            dataloader=None,
-            optimizer=None,
-            sampler=None,
-            epoch=0,
-            current_iter=0,
-            loss=0.0,
-            checkpoint=None
-        )
+        self._context = TrainContext()
 
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
@@ -43,32 +36,10 @@ class TrainContextBuilder:
                 model=self.trainer.parameter.model,
                 tokenizer=self.trainer.parameter.tokenizer,
                 config=self.trainer.parameter.config,
-                sampler_state=None,
-                optimizer_state=None,
-                loss_list=[]
             )
         self._context.checkpoint = checkpoint
         return self
 
-    def with_sampler(self) -> Self:
-        seed = self.trainer.train_config.random_seed
-        sampler = RandomSampler(
-            data_source=self.trainer.train_config.dataset,
-            seed=seed
-        )
-
-        if self._context.checkpoint and self._context.checkpoint.sampler_state:
-            sampler.load_state_dict(self._context.checkpoint.sampler_state)
-
-        self._context.sampler = sampler
-        self._context.epoch = sampler.epoch
-        self._context.current_iter = sampler.current_iter
-
-        if self._context.checkpoint:
-            self._context.checkpoint.sampler_state = sampler.state_dict()
-
-        return self
-
     def with_optimizer(self) -> Self:
         optimizer = self.trainer.train_config.optimizer
 
@@ -82,11 +53,22 @@ class TrainContextBuilder:
 
         return self
 
+    def with_scheduler(self) -> Self:
+        return self
+
+
     def with_dataloader(self) -> Self:
+        resumeable_sampler = ResumeableRandomSampler(
+            data_source=self.trainer.train_config.dataset,
+            start_epoch=self._context.epoch,
+            start_iter=self._context.current_iter,
+            seed=self.trainer.train_config.random_seed
+        )
+
         dataloader = DataLoader(
             self.trainer.train_config.dataset,
             batch_size=self.trainer.train_config.batch_size,
-            sampler=self._context.sampler,
+            sampler=resumeable_sampler,
             num_workers=self.trainer.train_config.num_workers,
             pin_memory=self.trainer.train_config.pin_memory,
             prefetch_factor=self.trainer.train_config.prefetch_factor
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index c287358..14878d0 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,9 +1,9 @@
 import logging
 
 from typing import Optional, List
-from khaosz.core import ModelParameter, Checkpoint
+from khaosz.config import ModelParameter, Checkpoint
 from khaosz.trainer.strategy import ScheduleConfig
-from khaosz.trainer.train_config import TrainConfig
+from khaosz.config.train_config import TrainConfig
 from khaosz.trainer.train_callback import (
     TrainCallback,
     ProgressBarCallback,
@@ -39,8 +39,8 @@ class Trainer:
     def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
         return (TrainContextBuilder(self)
                 .with_checkpoint(checkpoint)
-                .with_sampler()
                 .with_optimizer()
+                .with_scheduler()
                 .with_dataloader()
                 .build())
diff --git a/tests/conftest.py b/tests/conftest.py
index cba44cd..17412c9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,9 +9,11 @@ import pytest
 import matplotlib
 
 from torch.utils.data import Dataset
-from khaosz.core import *
-from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.config.model_config import TransformerConfig
+from khaosz.data.data_util import build_attention_mask, build_loss_mask
+from khaosz.data.tokenizer import BpeTokenizer
+from khaosz.model.transformer import Transformer
+
 
 
 matplotlib.use("Agg")
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index bb4e62b..c93ccd5 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -1,8 +1,7 @@
 import torch
 
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
 
 
 def test_callback_integration(base_test_env, random_dataset):
     """Test that all callbacks are properly integrated"""
diff --git a/tests/test_dataset_loader.py b/tests/test_dataset_loader.py
index 3fe29ad..b45f51a 100644
--- a/tests/test_dataset_loader.py
+++ b/tests/test_dataset_loader.py
@@ -3,9 +3,9 @@ import torch
 import pickle
 import numpy as np
 
-from khaosz.core import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
+
 
 def test_dataset_loader_random_paths(base_test_env):
     """Test dataset loader with multiple random paths"""
diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py
index 9fb5f99..c6a3d9b 100644
--- a/tests/test_early_stopping.py
+++ b/tests/test_early_stopping.py
@@ -1,8 +1,7 @@
 import torch
-
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+
 
 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
diff --git a/tests/test_module.py b/tests/test_module.py
index 8250570..9b341ab 100644
--- a/tests/test_module.py
+++ b/tests/test_module.py
@@ -5,8 +5,11 @@ import shutil
 import pytest
 import tempfile
 import safetensors.torch as st
-from khaosz.core import *
-from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
+from khaosz.trainer import *
+from khaosz.config import *
+from khaosz.model import *
+from khaosz.data import *
+from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
 from tokenizers import pre_tokenizers
 
 
 @pytest.fixture
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 9109c29..3ff1568 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -1,14 +1,13 @@
-from khaosz.core import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
+
 
 def test_random_sampler_consistency(random_dataset):
     """Test RandomSampler produces consistent results with same seed"""
     dataset = random_dataset
 
     # Create two samplers with same seed
-    sampler1 = RandomSampler(dataset, seed=42)
-    sampler2 = RandomSampler(dataset, seed=42)
+    sampler1 = ResumeableRandomSampler(dataset, seed=42)
+    sampler2 = ResumeableRandomSampler(dataset, seed=42)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -20,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
     dataset = random_dataset
 
     # Create two samplers with different seeds
-    sampler1 = RandomSampler(dataset, seed=42)
-    sampler2 = RandomSampler(dataset, seed=123)
+    sampler1 = ResumeableRandomSampler(dataset, seed=42)
+    sampler2 = ResumeableRandomSampler(dataset, seed=123)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -29,38 +28,13 @@ def test_random_sampler_different_seeds(random_dataset):
     # Very high probability they should be different
     assert indices1 != indices2
 
-def test_sampler_state_persistence(random_dataset):
-    """Test that sampler state is correctly saved and loaded"""
-    dataset = random_dataset
-    n = len(dataset)
-
-    # Create sampler and get some indices
-    sampler = RandomSampler(dataset, seed=42)
-    iter1 = iter(sampler)
-    indices1 = [next(iter1) for _ in range(min(10, n))]
-
-    # Save state
-    state_dict = sampler.state_dict()
-
-    # Get more indices
-    indices2 = [next(iter1) for _ in range(min(10, n - len(indices1)))]
-
-    # Create new sampler and load state
-    sampler2 = RandomSampler(dataset, seed=42)
-    sampler2.load_state_dict(state_dict)
-
-    # Check that new sampler produces same sequence from saved point
-    iter2 = iter(sampler2)
-    indices3 = [next(iter2) for _ in range(min(10, n - len(indices1)))]
-
-    assert indices2 == indices3
 
 def test_sampler_across_epochs(random_dataset):
     """Test sampler behavior across multiple epochs"""
     dataset = random_dataset
     n = len(dataset)
 
-    sampler = RandomSampler(dataset, seed=42)
+    sampler = ResumeableRandomSampler(dataset, seed=42)
 
     # Get indices for first epoch
     epoch1_indices = list(iter(sampler))
diff --git a/tests/test_train_config.py b/tests/test_train_config.py
index 60f2bd6..972f256 100644
--- a/tests/test_train_config.py
+++ b/tests/test_train_config.py
@@ -1,9 +1,10 @@
 import torch
 import numpy as np
-from khaosz.core import *
+
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
+
 
 def test_different_batch_sizes(base_test_env, random_dataset):
     """Test training with different batch sizes"""
diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py
index 100130b..1ce6e23 100644
--- a/tests/test_train_strategy.py
+++ b/tests/test_train_strategy.py
@@ -1,9 +1,9 @@
 import torch
 import numpy as np
 
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 
 
 def test_multi_turn_training(base_test_env, multi_turn_dataset):
     """Test training with multi-turn conversation data"""
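
For downstream users, the restructure only moves import paths; behavior is unchanged. A minimal migration sketch, assuming the package is installed as `khaosz` (each comment names the old location this patch removes):

    # New import surface after this patch.
    from khaosz import Khaosz                                            # was khaosz.model
    from khaosz.config import TransformerConfig, ParameterLoader, TrainConfig
    # ^ were khaosz.core.transformer, khaosz.core.parameter, khaosz.trainer.train_config
    from khaosz.model.transformer import Transformer                     # was khaosz.core.transformer
    from khaosz.data import BpeTokenizer, DatasetLoader                  # were khaosz.core.tokenizer, khaosz.trainer.data_util
    from khaosz.inference.generator import TextGenerator, ChatGenerator  # was khaosz.core.generator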
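
The `TransformerConfig` extracted into khaosz/config/model_config.py keeps its JSON round trip: `save` serializes only the non-None fields, and `load` mutates the instance in place and returns `self`, so calls chain. A small sketch; the hyperparameter values below are made up for illustration:

    from khaosz.config import TransformerConfig

    # Hypothetical values, not defaults shipped by the project.
    cfg = TransformerConfig(vocab_size=32000, n_dim=768, n_head=12, n_layer=12,
                            m_len=2048, norm_eps=1e-5, d_ffn=3072, n_kvhead=4)
    cfg.save("config.json")                             # None fields are dropped
    restored = TransformerConfig().load("config.json")  # fluent: returns self
    assert restored.n_dim == 768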
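
`ResumeableRandomSampler` replaces the old `RandomSampler` plus its pickled `state_dict`: rather than serializing generator state, it recomputes its position from `(start_epoch, start_iter, seed)` by drawing one throwaway permutation per completed epoch, which is how `TrainContextBuilder.with_dataloader` now resumes. A sketch against a toy dataset (the dataset itself is hypothetical):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from khaosz.data import ResumeableRandomSampler

    dataset = TensorDataset(torch.arange(100))  # stand-in for a real training set

    # Fresh run: starts at epoch 0, iteration 0.
    fresh = ResumeableRandomSampler(dataset, seed=42)

    # Resume: replay the seeded generator past 2 completed epochs, then skip
    # 17 indices into the current permutation (start_iter is taken modulo len(dataset)).
    resumed = ResumeableRandomSampler(dataset, start_epoch=2, start_iter=17, seed=42)
    loader = DataLoader(dataset, batch_size=8, sampler=resumed)

The trade-off is a little redundant RNG work at resume time in exchange for not checkpointing generator state, which is why `CheckpointCallback` no longer writes a sampler state dict.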