From c01791ff5402686798ad01a6212e63e8df6b8f96 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 30 Mar 2026 00:55:15 +0800 Subject: [PATCH] feat: add factory pattern for the inference module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- demo/generate_ar.py | 26 ++++--- demo/generate_batch.py | 21 ++++-- demo/generate_retrieve.py | 43 ----------- demo/stream_chat.py | 37 ++++++--- khaosz/__init__.py | 15 +--- khaosz/api.py | 138 ---------------------------------- khaosz/config/param_config.py | 11 ++- khaosz/inference/__init__.py | 4 + khaosz/inference/core.py | 33 ++++++-- khaosz/inference/generator.py | 123 ++++++++++++++++++++---------- khaosz/utils/__init__.py | 1 - khaosz/utils/retriever.py | 88 ---------------------- khaosz/utils/splitter.py | 127 ------------------------------- tests/module/test_module.py | 2 +- tools/generate.py | 76 ++++++------------- tools/perplexity.py | 28 ++++--- tools/train.py | 13 +--- 17 files changed, 227 insertions(+), 559 deletions(-) delete mode 100644 demo/generate_retrieve.py delete mode 100644 khaosz/api.py delete mode 100644 khaosz/utils/__init__.py delete mode 100644 khaosz/utils/retriever.py delete mode 100644 khaosz/utils/splitter.py diff --git a/demo/generate_ar.py b/demo/generate_ar.py index 77c216c..b8c0a26 100644 --- a/demo/generate_ar.py +++ b/demo/generate_ar.py @@ -1,27 +1,35 @@ import os import torch -from khaosz import Khaosz +from khaosz.config.param_config import ModelParameter +from khaosz.inference.core import disable_random_init +from khaosz.inference.generator import LoopGenerator, GenerationRequest PROJECT_ROOT = os.path.dirname( os.path.dirname(os.path.abspath(__file__))) def generate_text(): - model_dir = os.path.join(PROJECT_ROOT, "params") - model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) - + + with disable_random_init(): + model_dir = os.path.join(PROJECT_ROOT, "params") + param = ModelParameter.load(model_dir) + + param.to(device='cuda', dtype=torch.bfloat16) query = input(">> ") - response = model.text_generate( - query=query, + request = GenerationRequest( + query=query, temperature=0.8, top_p=0.95, - top_k=50 + top_k=50, + max_len=param.config.max_len, + history=None, + system_prompt=None, ) + generator = LoopGenerator(param) + response = generator.generate(request) print(response) - - if __name__ == "__main__": generate_text() \ No newline at end of file diff --git a/demo/generate_batch.py b/demo/generate_batch.py index 85a563f..3b39b04 100644 --- a/demo/generate_batch.py +++ b/demo/generate_batch.py @@ -1,22 +1,31 @@ import os import torch -from khaosz import Khaosz - +from khaosz.config.param_config import ModelParameter +from khaosz.inference.core import disable_random_init +from khaosz.inference.generator import BatchGenerator, GenerationRequest PROJECT_ROOT = os.path.dirname( os.path.dirname(os.path.abspath(__file__))) def batch_generate(): - model_dir = os.path.join(PROJECT_ROOT, "params") - model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) + with disable_random_init(): + model_dir = os.path.join(PROJECT_ROOT, "params") + param = ModelParameter.load(model_dir) + + param.to(device='cuda', dtype=torch.bfloat16) + generator = BatchGenerator(param) inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"] - responses = model.batch_generate( + request = GenerationRequest( query=inputs, temperature=0.8, top_p=0.95, - top_k=50 +
top_k=50, + max_len=param.config.max_len, + history=None, + system_prompt=None, ) + responses = generator.generate(request) for q, r in zip(inputs, responses): print((q, r)) diff --git a/demo/generate_retrieve.py b/demo/generate_retrieve.py deleted file mode 100644 index c49ac54..0000000 --- a/demo/generate_retrieve.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -import torch -from khaosz import Khaosz, SemanticTextSplitter, Retriever - - -PROJECT_ROOT = os.path.dirname( - os.path.dirname(os.path.abspath(__file__))) - -if __name__ == "__main__": - model_dir = os.path.join(PROJECT_ROOT, "params") - context_path = os.path.join(PROJECT_ROOT, "README.md") - - model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) - spliter = SemanticTextSplitter(model.encode) - retriever = Retriever() - text = open(context_path, "r", encoding="utf-8").read() - - res = spliter.split(text, threshold=0.8, window_size=1) - # print(("\n" + "+"*100 + "\n").join(res)) - - res_embs = model.encode(res) - for sentence, emb in zip(res, res_embs): - retriever.add_vector(sentence, emb) - - retrive_top_k = 5 - query = "作者设计了一个怎样的模型" - emb_query = model.encode(query) - retrieved = retriever.retrieve(emb_query, retrive_top_k) - retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)]) - - retrive_response = model.retrieve_generate( - retrieved=retrieved_content, - query=query, - temperature=0.8, - top_p=0.95, - top_k=50 - ) - - print("retrieve content:") - print(retrieved_content) - - print("\n\nretrive generate:") - print(retrive_response) \ No newline at end of file diff --git a/demo/stream_chat.py b/demo/stream_chat.py index deabc5c..abddd88 100644 --- a/demo/stream_chat.py +++ b/demo/stream_chat.py @@ -1,14 +1,21 @@ import os import torch -from khaosz import Khaosz +from khaosz.config.param_config import ModelParameter +from khaosz.inference.core import disable_random_init +from khaosz.inference.generator import StreamGenerator, GenerationRequest PROJECT_ROOT = os.path.dirname( os.path.dirname(os.path.abspath(__file__))) def chat(): - model_dir = os.path.join(PROJECT_ROOT, "params") - model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) + + with disable_random_init(): + model_dir = os.path.join(PROJECT_ROOT, "params") + param = ModelParameter.load(model_dir) + + param.to(device='cuda', dtype=torch.bfloat16) + generator = StreamGenerator(param) history = [] while True: @@ -16,17 +23,27 @@ def chat(): if query == "!exit": break - response_size = 0 - for response, history in model.stream_generate( - query=query, - history=history, + request = GenerationRequest( + query=query, temperature=0.8, top_p=0.95, - top_k=50 - ): + top_k=50, + max_len=param.config.max_len, + history=history, + system_prompt=None, + ) + + response_size = 0 + full_response = "" + for response in generator.generate(request): + # response is the cumulative response up to current token print(response[response_size:], end="", flush=True) response_size = len(response) - + full_response = response + + # After generation, update history + history.append((query, full_response.strip())) + if __name__ == "__main__": chat() \ No newline at end of file diff --git a/khaosz/__init__.py b/khaosz/__init__.py index f9a14f1..5c02a8e 100644 --- a/khaosz/__init__.py +++ b/khaosz/__init__.py @@ -1,17 +1,11 @@ __version__ = "1.3.2" __author__ = "ViperEkura" -from khaosz.api import Khaosz from khaosz.config import ( ModelConfig, TrainConfig, ) from khaosz.model.transformer import Transformer -from 
khaosz.utils.retriever import Retriever -from khaosz.utils.splitter import ( - SemanticTextSplitter, - PriorityTextSplitter -) from khaosz.data import ( DatasetLoader, BpeTokenizer @@ -22,8 +16,8 @@ from khaosz.inference.generator import ( StreamGenerator, BatchGenerator, EmbeddingEncoder, + GeneratorFactory ) - from khaosz.trainer import ( Trainer, StrategyFactory, @@ -31,14 +25,8 @@ from khaosz.trainer import ( ) __all__ = [ - "Khaosz", - "Transformer", - "Retriever", - "SemanticTextSplitter", - "PriorityTextSplitter", - "ModelConfig", "TrainConfig", @@ -50,6 +38,7 @@ __all__ = [ "StreamGenerator", "BatchGenerator", "EmbeddingEncoder", + "GeneratorFactory", "Trainer", "StrategyFactory", diff --git a/khaosz/api.py b/khaosz/api.py deleted file mode 100644 index ef9f321..0000000 --- a/khaosz/api.py +++ /dev/null @@ -1,138 +0,0 @@ -from torch import nn -from torch import Tensor -from contextlib import contextmanager -from typing import List, Tuple, Generator, Union - -from khaosz.inference.generator import ( - GenerationRequest, - LoopGenerator, - StreamGenerator, - BatchGenerator, - EmbeddingEncoder -) -from khaosz.config.param_config import ModelParameter - -@contextmanager -def disable_random_init(): - init_functions = [ - 'xavier_normal_', 'xavier_uniform_', - 'kaiming_normal_', 'kaiming_uniform_', - 'zeros_', 'ones_', 'constant_', - 'normal_', 'uniform_' - ] - original_funcs = {} - for name in init_functions: - if hasattr(nn.init, name): - original_funcs[name] = getattr(nn.init, name) - setattr(nn.init, name, lambda *args, **kwargs: None) - try: - yield - finally: - for name, orig_func in original_funcs.items(): - setattr(nn.init, name, orig_func) - - -class Khaosz: - def __init__(self, model_dir: str): - with disable_random_init(): - self.parameter = ModelParameter() - self.parameter.load(model_dir) - - def to(self, *args, **kwargs): - self.parameter.to(*args, **kwargs) - return self - - def generate( - self, - query: str, - history: List[Tuple[str, str]]=None, - temperature: float=0.8, - top_k: int=50, - top_p: float=0.95, - ) -> str: - generator = LoopGenerator(self.parameter) - return generator.generate( - GenerationRequest( - top_k, top_p, temperature, - self.parameter.config.max_len, - query=query, - history=history, - build_prompt=True - )) - - def batch_generate( - self, - query: List[str], - history: List[Tuple[str, str]]=None, - temperature: float=0.8, - top_k: int=50, - top_p: float=0.95, - ) -> List[str]: - generator = BatchGenerator(self.parameter) - return generator.generate( - GenerationRequest( - top_k, top_p, temperature, - self.parameter.config.max_len, - query=query, - history=history, - build_prompt=True - )) - - def stream_generate( - self, - query: str, - history: List[Tuple[str, str]]=None, - temperature: float=0.8, - top_k: int=50, - top_p: float=0.95, - ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]: - stream_generator = StreamGenerator(self.parameter) - return stream_generator.generate( - GenerationRequest( - top_k, top_p, temperature, - self.parameter.config.max_len, - query=query, - history=history, - build_prompt=True - )) - - def retrieve_generate( - self, - retrieved, - query: str, - history: List[Tuple[str, str]] = None, - temperature: float=0.8, - top_k: int=50, - top_p: float=0.95, - ) -> str: - generator = LoopGenerator(self.parameter) - return generator.generate( - GenerationRequest( - top_k, top_p, temperature, - self.parameter.config.max_len, - query=query, - history=history, - system_prompt=retrieved, - build_prompt=True - )) - - def 
text_generate( - self, - query: str, - temperature: float=0.8, - top_k: int=50, - top_p: float=0.95, - ) -> str: - generator = LoopGenerator(self.parameter) - return generator.generate( - GenerationRequest( - top_k, top_p, temperature, - self.parameter.config.max_len, - query=query, - build_prompt=False - )) - - - def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: - encoder = EmbeddingEncoder(self.parameter) - return encoder.encode(sentence) \ No newline at end of file diff --git a/khaosz/config/param_config.py b/khaosz/config/param_config.py index 4d29813..6036a08 100644 --- a/khaosz/config/param_config.py +++ b/khaosz/config/param_config.py @@ -72,9 +72,12 @@ class BaseModelIO: class ModelParameter(BaseModelIO): """Container for model parameters with serialization capabilities.""" - def save(self, save_dir: Union[str, Path]): - self.save_components(save_dir) + @classmethod + def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]): + instance.save_components(save_dir) - def load(self, load_dir: Union[str, Path]) -> "ModelParameter": - return self.load_components(load_dir) + @classmethod + def load(cls, load_dir: Union[str, Path]) -> "ModelParameter": + instance = cls() + return instance.load_components(load_dir) diff --git a/khaosz/inference/__init__.py b/khaosz/inference/__init__.py index f94aec6..c652dc4 100644 --- a/khaosz/inference/__init__.py +++ b/khaosz/inference/__init__.py @@ -1,4 +1,5 @@ from khaosz.inference.core import ( + disable_random_init, GeneratorCore, EmbeddingEncoderCore, KVCacheManager, @@ -10,9 +11,11 @@ from khaosz.inference.generator import ( StreamGenerator, BatchGenerator, EmbeddingEncoder, + GeneratorFactory ) __all__ = [ + "disable_random_init", "GeneratorCore", "EmbeddingEncoderCore", "KVCacheManager", @@ -22,4 +25,5 @@ __all__ = [ "StreamGenerator", "BatchGenerator", "EmbeddingEncoder", + "GeneratorFactory" ] \ No newline at end of file diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py index ddbb24a..d2fbcc7 100644 --- a/khaosz/inference/core.py +++ b/khaosz/inference/core.py @@ -1,5 +1,8 @@ import torch +import torch.nn as nn + from torch import Tensor +from contextlib import contextmanager from typing import Any, Callable, List, Tuple, Union, Optional, Self from khaosz.config import ModelParameter, ModelConfig @@ -54,6 +57,26 @@ def apply_sampling_strategies( return logits +@contextmanager +def disable_random_init(): + init_functions = [ + 'xavier_normal_', 'xavier_uniform_', + 'kaiming_normal_', 'kaiming_uniform_', + 'zeros_', 'ones_', 'constant_', + 'normal_', 'uniform_' + ] + original_funcs = {} + for name in init_functions: + if hasattr(nn.init, name): + original_funcs[name] = getattr(nn.init, name) + setattr(nn.init, name, lambda *args, **kwargs: None) + try: + yield + finally: + for name, orig_func in original_funcs.items(): + setattr(nn.init, name, orig_func) + + class GeneratorCore: def __init__(self, parameter: ModelParameter): self.model = parameter.model @@ -82,10 +105,6 @@ class GeneratorCore: return next_token_id, cache_increase - def to(self, *args, **kargs) -> Self: - self.model.to(*args, **kargs) - return self - def generate_loop( self, input_ids: Tensor, @@ -115,6 +134,10 @@ class GeneratorCore: break return ids + + def to(self, *args, **kargs) -> Self: + self.model.to(*args, **kargs) + return self class EmbeddingEncoderCore: @@ -203,7 +226,7 @@ class KVCacheManager: self._kv_cache: Tuple[Tensor, Tensor] = None self._seq_mask: Tensor = None self._initialize() - + def 
_initialize(self): k_cache = torch.zeros( (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py index 877807b..d42489a 100644 --- a/khaosz/inference/generator.py +++ b/khaosz/inference/generator.py @@ -9,33 +9,37 @@ from khaosz.config.param_config import ModelParameter HistoryType = List[Tuple[str, str]] def build_prompt( - query: str, - init_prompt: Optional[str] = None, - history: Optional[List[Tuple[str, str]]] = None - ) -> str: - """ - Build prompt in ChatML format for query and history - - Args: - query(str): query string - history(Optional[List[Tuple[str, str]]]): history list of query and response - - Returns: - str: prompt string in ChatML format - + query: str, + system_prompt: Optional[str] = None, + history: Optional[HistoryType] = None +) -> str: """ - prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else "" - + Build prompt in ChatML format for query and history. + + Args: + query (str): query string. + system_prompt (Optional[str]): system prompt string. + history (Optional[HistoryType]): history list of query and response. + + Returns: + str: prompt string in ChatML format. + """ + result = "" + + if system_prompt: + result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + # (convert tuple format to ChatML) if history: for user_msg, assistant_msg in history: - prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n" - prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n" - - prompt += f"<|im_start|>user\n{query}<|im_end|>\n" - prompt += "<|im_start|>assistant\n" - - return prompt + result += f"<|im_start|>user\n{user_msg}<|im_end|>\n" + result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n" + + result += f"<|im_start|>user\n{query}<|im_end|>\n" + result += "<|im_start|>assistant\n" + + return result + def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]: """ @@ -59,8 +63,21 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int] return new_ids_list, max_ids_len + @dataclass class GenerationRequest: + """ + Request parameters for text generation. + + Attributes: + top_k: Top-k sampling parameter. + top_p: Top-p (nucleus) sampling parameter. + temperature: Sampling temperature. + max_len: Maximum generation length. + query: Input query (string or list of strings for batch). + history: Conversation history. + system_prompt: System prompt for the conversation. 
+ """ top_k: int top_p: float temperature: float @@ -70,8 +87,6 @@ class GenerationRequest: history: Optional[Union[HistoryType, List[HistoryType]]] = None system_prompt: Optional[str] = None - build_prompt: bool = True - def __post_init__(self): if not isinstance(self.top_k, int) or self.top_k < 0: raise ValueError("top_k must be a non-negative integer") @@ -89,19 +104,21 @@ class LoopGenerator(GeneratorCore): device = next(self.model.parameters()).device cache_manager = KVCacheManager(self.config, 1, device=device) - input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query - ids = self.tokenizer.encode(input_args) + prompt = build_prompt(request.query, request.history) + ids = self.tokenizer.encode(prompt) input_ids = torch.tensor([ids], device=device, dtype=torch.long) start_cache_pos = len(ids) - cur_cache_pos = 0 self.model.eval() kv_caches = cache_manager.get_kvcache() ids = self.generate_loop( - input_ids, ids, request.temperature, request.top_k, request.top_p, + input_ids, + ids, + request.temperature, + request.top_k, + request.top_p, kv_caches=kv_caches, - start_pos=cur_cache_pos ) response = self.tokenizer.decode(ids[start_cache_pos:]) @@ -112,16 +129,12 @@ class StreamGenerator(GeneratorCore): def __init__(self, parameter: ModelParameter): super().__init__(parameter) - def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]: - - if request.history is None: - request.history = [] - + def generate(self, request: GenerationRequest) -> Generator[str, None, None]: device = next(self.model.parameters()).device cache_manager = KVCacheManager(self.config, 1, device=device) - input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query - ids = self.tokenizer.encode(input_args) + prompt = build_prompt(request.query, request.history) + ids = self.tokenizer.encode(prompt) input_ids = torch.tensor([ids], device=device, dtype=torch.long) start_cache_pos = len(ids) @@ -141,10 +154,10 @@ class StreamGenerator(GeneratorCore): cur_cache_pos += cache_increase response = self.tokenizer.decode(ids[start_cache_pos:]) - yield response, request.history + [(request.query, response)] + yield response if next_token_id.item() in self.tokenizer.stop_ids: - yield response + "\n", request.history + [(request.query, response)] + yield response + "\n" break @@ -217,4 +230,36 @@ class EmbeddingEncoder(EmbeddingEncoderCore): def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]: return super().encode(sentence) + + +class GeneratorFactory: + """Factory class for creating appropriate generator instances based on request features.""" + + @staticmethod + def create_generator(parameter: ModelParameter, request: GenerationRequest): + """ + Create a generator based on the characteristics of GenerationRequest. 
+ Args: + parameter: Model parameters + request: Generation request + + Returns: + Subclass instance of GeneratorCore + """ + + # Streaming generation detection: check stream field + if request.stream: + return StreamGenerator(parameter) + + # Batch generation detection: query is a list + if isinstance(request.query, list): + return BatchGenerator(parameter) + + # Default return LoopGenerator + return LoopGenerator(parameter) + + @staticmethod + def create_encoder(parameter: ModelParameter): + """Create an EmbeddingEncoder instance""" + return EmbeddingEncoder(parameter) \ No newline at end of file diff --git a/khaosz/utils/__init__.py b/khaosz/utils/__init__.py deleted file mode 100644 index 5bd6efa..0000000 --- a/khaosz/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# init file \ No newline at end of file diff --git a/khaosz/utils/retriever.py b/khaosz/utils/retriever.py deleted file mode 100644 index 1264474..0000000 --- a/khaosz/utils/retriever.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import sqlite3 -import numpy as np -from torch import Tensor -from typing import Dict, List, Tuple - - -class Retriever: - def __init__(self, db_path=None): - self.data: Dict[str, Tensor] = {} - self.embedding_cache: Tensor = None - self.is_caculated: bool = False - - if db_path is not None: - self.load(db_path) - - def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]: - if not self.data: - return [] - - query = query.flatten().unsqueeze(1) # [dim, 1] - norm_embeddings = self._embeddings.to( - device=query.device, - dtype=query.dtype - ) # [n_vectors, dim] - sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors] - - top_k = min(top_k, len(self.data)) - indices = sim_scores.topk(top_k).indices - keys = list(self.data.keys()) - - return [(keys[i], sim_scores[i].item()) for i in indices] - - def add_vector(self, key: str, vector_data: Tensor): - self.is_caculated = False - self.data[key] = vector_data.flatten().float().cpu() - - def delete_vector(self, key: str): - self.is_caculated = False - self.data.pop(key, None) - - def save(self, db_path): - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - self._init_db(cursor) - cursor.execute('DELETE FROM vectors') - - for item, vec in self.data.items(): - vec_bytes = vec.numpy().tobytes() - cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)', - (item, vec_bytes)) - - conn.commit() - conn.close() - - def load(self, db_path): - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - self._init_db(cursor) - cursor.execute('SELECT key, vector FROM vectors') - rows = cursor.fetchall() - self.data = {} - - for row in rows: - key, vec_bytes = row - vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy() - vec = torch.from_numpy(vec_numpy) - self.data[key] = vec - - conn.close() - - def _init_db(self,cursor: sqlite3.Cursor): - # Create table if not exists (in case loading from a new database) - cursor.execute(''' - CREATE TABLE IF NOT EXISTS vectors ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - key TEXT UNIQUE NOT NULL, - vector BLOB NOT NULL - )''') - - @property - def _embeddings(self) -> Tensor: - if not self.is_caculated: - embeddings = torch.stack(list(self.data.values())) - norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True) - self.embedding_cache = norm_embeddings - - return self.embedding_cache \ No newline at end of file diff --git a/khaosz/utils/splitter.py b/khaosz/utils/splitter.py deleted file mode 100644 index 7f15fd3..0000000 --- 
a/khaosz/utils/splitter.py +++ /dev/null @@ -1,127 +0,0 @@ -import re -import torch -import torch.nn.functional as F - -from abc import ABC, abstractmethod -from torch import Tensor -from typing import List, Callable, Optional - - -class BaseTextSplitter(ABC): - def __init__( - self, - max_len: int = 512, - chunk_overlap: int = 0, - ): - if max_len <= 0: - raise ValueError("max_len must be > 0") - if chunk_overlap < 0: - raise ValueError("chunk_overlap must be >= 0") - - self.max_len = max_len - self.chunk_overlap = chunk_overlap - - @abstractmethod - def split(self, text: str, **kwargs) -> List[str]: - raise NotImplementedError - - def preprocess(self, text: str) -> str: - return text.strip() - - def postprocess(self, chunks: List[str]) -> List[str]: - return [chunk.strip() for chunk in chunks if chunk.strip()] - - -class PriorityTextSplitter(BaseTextSplitter): - def __init__( - self, - separators: List[str], - max_len: int = 512, - chunk_overlap: int = 0, - ): - super().__init__(max_len=max_len, chunk_overlap=chunk_overlap) - if not separators: - raise ValueError("separators must be a non-empty list") - self.separators = separators - - def split(self, text: str) -> List[str]: - text = self.preprocess(text) - for sep in self.separators: - parts = text.split(sep) - - valid_parts = [p.strip() for p in parts if p.strip()] - if len(valid_parts) > 1: - return self.postprocess(valid_parts) - return [text] - - -class SemanticTextSplitter(BaseTextSplitter): - - DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)' - - def __init__( - self, - embedding_func: Callable[[List[str]], List[Tensor]], - pattern: Optional[str] = None, - max_len: int = 512, - chunk_overlap: int = 0, - ): - super().__init__(max_len=max_len, chunk_overlap=chunk_overlap) - if not callable(embedding_func): - raise TypeError("embedding_func must be callable") - self.embedding_func = embedding_func - self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN - - def split( - self, - text: str, - threshold: float = 0.5, - window_size: int = 1, - ) -> List[str]: - text = self.preprocess(text) - sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()] - - if len(sentences) <= 1: - return self.postprocess(sentences) - - try: - sentence_embs = self.embedding_func(sentences) - except Exception as e: - raise RuntimeError(f"Embedding generation failed: {e}") - - if len(sentence_embs) != len(sentences): - raise ValueError("Embedding function must return one vector per sentence") - - chunks = [] - emb_tensor = torch.stack(sentence_embs) # shape: [N, D] - current_chunk: List[str] = [sentences[0]] - - for i in range(1, len(sentences)): - start_prev = max(0, i - window_size) - end_prev = i - start_next = i - end_next = min(len(sentences), i + window_size) - - prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0) - next_window_emb = emb_tensor[start_next:end_next].mean(dim=0) - - similarity = F.cosine_similarity( - prev_window_emb.unsqueeze(0), - next_window_emb.unsqueeze(0), - dim=1 - ).item() - - dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2) - - if similarity < dynamic_threshold: - chunks.append(" ".join(current_chunk)) - overlap_start = max(0, len(current_chunk) - self.chunk_overlap) - current_chunk = current_chunk[overlap_start:] - current_chunk.append(sentences[i]) - else: - current_chunk.append(sentences[i]) - - if current_chunk: - chunks.append(" ".join(current_chunk)) - - return self.postprocess(chunks) \ No newline at end of file 
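Note: for reference, a minimal sketch of the ChatML prompt produced by the reworked build_prompt in khaosz/inference/generator.py above. The system prompt and the assistant reply in the history are hypothetical placeholder values; the marker strings come from the hunk itself.

    from khaosz.inference.generator import build_prompt

    prompt = build_prompt(
        query="请问什么是显卡",
        system_prompt="You are a helpful assistant.",        # hypothetical system prompt
        history=[("你好", "Hello! How can I help you?")],     # hypothetical previous turn
    )
    print(prompt)
    # Expected layout: one <|im_start|>...<|im_end|> block per message,
    # ending with an open assistant block for the model to complete:
    #   <|im_start|>system\nYou are a helpful assistant.<|im_end|>
    #   <|im_start|>user\n你好<|im_end|>
    #   <|im_start|>assistant\nHello! How can I help you?<|im_end|>
    #   <|im_start|>user\n请问什么是显卡<|im_end|>
    #   <|im_start|>assistant\n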
diff --git a/tests/module/test_module.py b/tests/module/test_module.py index 9d5857e..dce5508 100644 --- a/tests/module/test_module.py +++ b/tests/module/test_module.py @@ -54,7 +54,7 @@ def test_env(request: pytest.FixtureRequest): def test_model_parameter(test_env): save_dir = os.path.join(test_env["test_dir"], "save") model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"]) - model_param.save(save_dir) + ModelParameter.save(model_param, save_dir) assert os.path.exists(os.path.join(save_dir, "model.safetensors")) assert os.path.exists(os.path.join(save_dir, "tokenizer.json")) diff --git a/tools/generate.py b/tools/generate.py index cedefdc..d5c96b0 100644 --- a/tools/generate.py +++ b/tools/generate.py @@ -1,45 +1,10 @@ import torch import json -import torch import argparse -from khaosz import Khaosz -from typing import List -from tqdm import tqdm - - -def batch_generate( - model: Khaosz, - query: List[str], - temperature: float, - top_k: int, - top_p: float, - batch_size: int, -) -> List: - assert batch_size > 0 - sorted_query = sorted(query, key=lambda x: len(x), reverse=True) - original_indices = {query: idx for idx, query in enumerate(query)} - - responses = [None] * len(query) - total_batches = (len(sorted_query) + batch_size - 1) // batch_size - - for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"): - batch_query = sorted_query[i: min(i + batch_size, len(query))] - if not isinstance(batch_query, list): - batch_query = [batch_query] - - batch_responses = model.batch_generate( - query=batch_query, - temperature=temperature, - top_k=top_k, - top_p=top_p - ) - - for query, response in zip(batch_query, batch_responses): - original_idx = original_indices[query] - responses[original_idx] = response - - return responses +from khaosz.config.param_config import ModelParameter +from khaosz.inference.generator import BatchGenerator, GenerationRequest +from khaosz.inference.core import disable_random_init def processor( @@ -53,24 +18,31 @@ def processor( question_key: str, response_key: str, ): - model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16) - + with disable_random_init(): + param = ModelParameter.load(model_dir) + + param.to(device='cuda', dtype=torch.bfloat16) + generator = BatchGenerator(param) + with open(input_json_file, "r", encoding='utf-8') as f: input_data = [json.loads(line) for line in f] - query = [item[question_key] for item in input_data] - - responses = batch_generate( - model=model, - query=query, + + queries = [item[question_key] for item in input_data] + + request = GenerationRequest( + query=queries, temperature=temperature, - top_k=top_k, top_p=top_p, - batch_size=batch_size + top_k=top_k, + max_len=param.config.max_len, + history=None, + system_prompt=None, ) - - # Write output in JSONL format + + responses = generator.generate(request) + with open(output_json_file, "w", encoding='utf-8') as f: - for query, response in zip(query, responses): + for query, response in zip(queries, responses): output_item = {question_key: query, response_key: response} f.write(json.dumps(output_item, ensure_ascii=False) + '\n') @@ -89,4 +61,6 @@ if __name__ == "__main__": parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.") args = parser.parse_args() - processor(**vars(args)) \ No newline at end of file + + with torch.inference_mode(): + processor(**vars(args)) \ No newline at end of file diff --git a/tools/perplexity.py 
b/tools/perplexity.py index fe3bf56..2a9cc65 100644 --- a/tools/perplexity.py +++ b/tools/perplexity.py @@ -6,7 +6,9 @@ import argparse import tqdm from torch import Tensor -from khaosz import Khaosz +from khaosz.config.param_config import ModelParameter +from khaosz.inference.core import disable_random_init + def compute_perplexity( model: nn.Module, @@ -45,22 +47,23 @@ def process_file( batch_size: int, text_key: str ): - model = Khaosz(model_dir).to(device="cuda", dtype=torch.bfloat16) - tokenizer = model.parameter.tokenizer + with disable_random_init(): + param = ModelParameter.load(model_dir) + + param.to(device='cuda', dtype=torch.bfloat16) + model = param.model + tokenizer = param.tokenizer with open(input_file, "r", encoding='utf-8') as f: input_data = [json.loads(line) for line in f] texts = [item[text_key] for item in input_data] encoded_texts = [tokenizer.encode(text) for text in texts] - output_data = [] for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"): batch_encoded = encoded_texts[i:i + batch_size] batch_texts = texts[i:i + batch_size] - - # Pad sequences to the same length (left padding) max_len = max(len(seq) for seq in batch_encoded) padded_ids = [] masks = [] @@ -74,10 +77,7 @@ def process_file( input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long) input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool) - - # Compute perplexity - with torch.inference_mode(): - perplexity = compute_perplexity(model.parameter.model, input_ids, input_mask) + perplexity = compute_perplexity(model, input_ids, input_mask) for text, ppl in zip(batch_texts, perplexity): output_data.append({text_key: text, "ppl": float(ppl.item())}) @@ -87,16 +87,14 @@ def process_file( f.write(json.dumps(item, ensure_ascii=False) + '\n') -def main(): +if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.") parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.") parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.") parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.") parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.") parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.") - args = parser.parse_args() - process_file(**vars(args)) -if __name__ == "__main__": - main() + with torch.inference_mode(): + process_file(**vars(args)) diff --git a/tools/train.py b/tools/train.py index 1186206..1a09173 100644 --- a/tools/train.py +++ b/tools/train.py @@ -16,7 +16,7 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train the Transformer model.") - parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.") + parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.") parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.") parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.") @@ -67,18 +67,14 @@ def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler return SchedulerFactory.load(optimizer, **kwargs) def prepare_checkpoint(model: nn.Module) -> dict: - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = 
model.module.state_dict() - else: - state_dict = model.state_dict() - return state_dict + return model.module.state_dict() def train( train_type: str, param_path: str, data_root_path: str, - max_lr: int, + max_lr: float, n_epoch: int, batch_size: int, start_epoch: int, @@ -104,8 +100,7 @@ def train( assert train_type in ["seq", "sft", "dpo"] assert os.path.exists(param_path) - parameter = ModelParameter() - parameter.load(param_path) + parameter = ModelParameter.load(param_path) if window_size is None: window_size = parameter.config.max_len
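Usage note: with the Khaosz facade (khaosz/api.py) removed, callers assemble the pieces directly. Below is a minimal sketch of the new flow under the same assumptions as the demo scripts (a "params" checkpoint directory, a CUDA device, bfloat16 weights); GeneratorFactory picks LoopGenerator, BatchGenerator, or StreamGenerator from the request.

    import torch

    from khaosz.config.param_config import ModelParameter
    from khaosz.inference.core import disable_random_init
    from khaosz.inference.generator import GenerationRequest, GeneratorFactory

    # Skip the random weight initializers while the checkpoint is loaded.
    with disable_random_init():
        param = ModelParameter.load("params")
    param.to(device="cuda", dtype=torch.bfloat16)

    request = GenerationRequest(
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        max_len=param.config.max_len,
        query="请问什么是人工智能",   # a single string, so the factory returns a LoopGenerator
        history=None,
        system_prompt=None,
    )
    generator = GeneratorFactory.create_generator(param, request)
    print(generator.generate(request))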