feat: add a factory pattern for the inference module

ViperEkura 2026-03-30 00:55:15 +08:00
parent 980299cd54
commit c01791ff54
17 changed files with 227 additions and 559 deletions

View File

@@ -1,27 +1,35 @@
 import os
 import torch
-from khaosz import Khaosz
+from khaosz.config.param_config import ModelParameter
+from khaosz.inference.core import disable_random_init
+from khaosz.inference.generator import LoopGenerator, GenerationRequest
 
 PROJECT_ROOT = os.path.dirname(
     os.path.dirname(os.path.abspath(__file__)))
 
 
 def generate_text():
-    model_dir = os.path.join(PROJECT_ROOT, "params")
-    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+    with disable_random_init():
+        model_dir = os.path.join(PROJECT_ROOT, "params")
+        param = ModelParameter.load(model_dir)
+        param.to(device='cuda', dtype=torch.bfloat16)
 
     query = input(">> ")
-    response = model.text_generate(
+    request = GenerationRequest(
         query=query,
         temperature=0.8,
         top_p=0.95,
-        top_k=50
+        top_k=50,
+        max_len=param.config.max_len,
+        history=None,
+        system_prompt=None,
     )
+    generator = LoopGenerator(param)
+    response = generator.generate(request)
     print(response)
 
 
 if __name__ == "__main__":
     generate_text()

View File

@@ -1,22 +1,31 @@
 import os
 import torch
-from khaosz import Khaosz
+from khaosz.config.param_config import ModelParameter
+from khaosz.inference.core import disable_random_init
+from khaosz.inference.generator import BatchGenerator, GenerationRequest
 
 PROJECT_ROOT = os.path.dirname(
     os.path.dirname(os.path.abspath(__file__)))
 
 
 def batch_generate():
-    model_dir = os.path.join(PROJECT_ROOT, "params")
-    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+    with disable_random_init():
+        model_dir = os.path.join(PROJECT_ROOT, "params")
+        param = ModelParameter.load(model_dir)
+        param.to(device='cuda', dtype=torch.bfloat16)
+    generator = BatchGenerator(param)
 
     inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
-    responses = model.batch_generate(
+    request = GenerationRequest(
         query=inputs,
         temperature=0.8,
         top_p=0.95,
-        top_k=50
+        top_k=50,
+        max_len=param.config.max_len,
+        history=None,
+        system_prompt=None,
     )
+    responses = generator.generate(request)
 
     for q, r in zip(inputs, responses):
         print((q, r))

View File

@@ -1,43 +0,0 @@
import os
import torch
from khaosz import Khaosz, SemanticTextSplitter, Retriever
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
if __name__ == "__main__":
model_dir = os.path.join(PROJECT_ROOT, "params")
context_path = os.path.join(PROJECT_ROOT, "README.md")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
spliter = SemanticTextSplitter(model.encode)
retriever = Retriever()
text = open(context_path, "r", encoding="utf-8").read()
res = spliter.split(text, threshold=0.8, window_size=1)
# print(("\n" + "+"*100 + "\n").join(res))
res_embs = model.encode(res)
for sentence, emb in zip(res, res_embs):
retriever.add_vector(sentence, emb)
retrive_top_k = 5
query = "作者设计了一个怎样的模型"
emb_query = model.encode(query)
retrieved = retriever.retrieve(emb_query, retrive_top_k)
retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
retrive_response = model.retrieve_generate(
retrieved=retrieved_content,
query=query,
temperature=0.8,
top_p=0.95,
top_k=50
)
print("retrieve content:")
print(retrieved_content)
print("\n\nretrive generate:")
print(retrive_response)

View File

@@ -1,14 +1,21 @@
 import os
 import torch
-from khaosz import Khaosz
+from khaosz.config.param_config import ModelParameter
+from khaosz.inference.core import disable_random_init
+from khaosz.inference.generator import StreamGenerator, GenerationRequest
 
 PROJECT_ROOT = os.path.dirname(
     os.path.dirname(os.path.abspath(__file__)))
 
 
 def chat():
-    model_dir = os.path.join(PROJECT_ROOT, "params")
-    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+    with disable_random_init():
+        model_dir = os.path.join(PROJECT_ROOT, "params")
+        param = ModelParameter.load(model_dir)
+        param.to(device='cuda', dtype=torch.bfloat16)
+    generator = StreamGenerator(param)
 
     history = []
     while True:
@@ -16,17 +23,27 @@ def chat():
         if query == "!exit":
             break
 
-        response_size = 0
-        for response, history in model.stream_generate(
+        request = GenerationRequest(
             query=query,
-            history=history,
             temperature=0.8,
             top_p=0.95,
-            top_k=50
-        ):
+            top_k=50,
+            max_len=param.config.max_len,
+            history=history,
+            system_prompt=None,
+        )
+
+        response_size = 0
+        full_response = ""
+        for response in generator.generate(request):
+            # response is the cumulative response up to current token
             print(response[response_size:], end="", flush=True)
             response_size = len(response)
+            full_response = response
+
+        # After generation, update history
+        history.append((query, full_response.strip()))
 
 
 if __name__ == "__main__":
     chat()

View File

@@ -1,17 +1,11 @@
 __version__ = "1.3.2"
 __author__ = "ViperEkura"
 
-from khaosz.api import Khaosz
 from khaosz.config import (
     ModelConfig,
     TrainConfig,
 )
 from khaosz.model.transformer import Transformer
-from khaosz.utils.retriever import Retriever
-from khaosz.utils.splitter import (
-    SemanticTextSplitter,
-    PriorityTextSplitter
-)
 from khaosz.data import (
     DatasetLoader,
     BpeTokenizer
@@ -22,8 +16,8 @@ from khaosz.inference.generator import (
     StreamGenerator,
     BatchGenerator,
     EmbeddingEncoder,
+    GeneratorFactory
 )
 from khaosz.trainer import (
     Trainer,
     StrategyFactory,
@@ -31,14 +25,8 @@ from khaosz.trainer import (
 )
 
 __all__ = [
-    "Khaosz",
     "Transformer",
-    "Retriever",
-    "SemanticTextSplitter",
-    "PriorityTextSplitter",
     "ModelConfig",
     "TrainConfig",
@@ -50,6 +38,7 @@ __all__ = [
     "StreamGenerator",
     "BatchGenerator",
     "EmbeddingEncoder",
+    "GeneratorFactory",
     "Trainer",
     "StrategyFactory",

View File

@@ -1,138 +0,0 @@
from torch import nn
from torch import Tensor
from contextlib import contextmanager
from typing import List, Tuple, Generator, Union
from khaosz.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
EmbeddingEncoder
)
from khaosz.config.param_config import ModelParameter
@contextmanager
def disable_random_init():
init_functions = [
'xavier_normal_', 'xavier_uniform_',
'kaiming_normal_', 'kaiming_uniform_',
'zeros_', 'ones_', 'constant_',
'normal_', 'uniform_'
]
original_funcs = {}
for name in init_functions:
if hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
class Khaosz:
def __init__(self, model_dir: str):
with disable_random_init():
self.parameter = ModelParameter()
self.parameter.load(model_dir)
def to(self, *args, **kwargs):
self.parameter.to(*args, **kwargs)
return self
def generate(
self,
query: str,
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = LoopGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def batch_generate(
self,
query: List[str],
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> List[str]:
generator = BatchGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def stream_generate(
self,
query: str,
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
stream_generator = StreamGenerator(self.parameter)
return stream_generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def retrieve_generate(
self,
retrieved,
query: str,
history: List[Tuple[str, str]] = None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = LoopGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
system_prompt=retrieved,
build_prompt=True
))
def text_generate(
self,
query: str,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = LoopGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
build_prompt=False
))
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
encoder = EmbeddingEncoder(self.parameter)
return encoder.encode(sentence)

View File

@@ -72,9 +72,12 @@ class BaseModelIO:
 class ModelParameter(BaseModelIO):
     """Container for model parameters with serialization capabilities."""
 
-    def save(self, save_dir: Union[str, Path]):
-        self.save_components(save_dir)
+    @classmethod
+    def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
+        instance.save_components(save_dir)
 
-    def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
-        return self.load_components(load_dir)
+    @classmethod
+    def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
+        instance = cls()
+        return instance.load_components(load_dir)
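
A minimal usage sketch of the reworked classmethod-style API (the "params" directory name is illustrative):

# load() builds a fresh ModelParameter; save() now takes the instance explicitly.
param = ModelParameter.load("params")
ModelParameter.save(param, "params")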

View File

@@ -1,4 +1,5 @@
 from khaosz.inference.core import (
+    disable_random_init,
     GeneratorCore,
     EmbeddingEncoderCore,
     KVCacheManager,
@@ -10,9 +11,11 @@ from khaosz.inference.generator import (
     StreamGenerator,
     BatchGenerator,
     EmbeddingEncoder,
+    GeneratorFactory
 )
 
 __all__ = [
+    "disable_random_init",
     "GeneratorCore",
     "EmbeddingEncoderCore",
     "KVCacheManager",
@@ -22,4 +25,5 @@ __all__ = [
     "StreamGenerator",
     "BatchGenerator",
     "EmbeddingEncoder",
+    "GeneratorFactory"
 ]

View File

@@ -1,5 +1,8 @@
 import torch
+import torch.nn as nn
 from torch import Tensor
+from contextlib import contextmanager
 from typing import Any, Callable, List, Tuple, Union, Optional, Self
 from khaosz.config import ModelParameter, ModelConfig
@@ -54,6 +57,26 @@ def apply_sampling_strategies(
     return logits
 
 
+@contextmanager
+def disable_random_init():
+    init_functions = [
+        'xavier_normal_', 'xavier_uniform_',
+        'kaiming_normal_', 'kaiming_uniform_',
+        'zeros_', 'ones_', 'constant_',
+        'normal_', 'uniform_'
+    ]
+    original_funcs = {}
+
+    for name in init_functions:
+        if hasattr(nn.init, name):
+            original_funcs[name] = getattr(nn.init, name)
+            setattr(nn.init, name, lambda *args, **kwargs: None)
+    try:
+        yield
+    finally:
+        for name, orig_func in original_funcs.items():
+            setattr(nn.init, name, orig_func)
+
+
 class GeneratorCore:
     def __init__(self, parameter: ModelParameter):
         self.model = parameter.model
@@ -82,10 +105,6 @@ class GeneratorCore:
         return next_token_id, cache_increase
 
-    def to(self, *args, **kargs) -> Self:
-        self.model.to(*args, **kargs)
-        return self
-
     def generate_loop(
         self,
         input_ids: Tensor,
@@ -115,6 +134,10 @@ class GeneratorCore:
                 break
         return ids
 
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
+
 
 class EmbeddingEncoderCore:
@@ -203,7 +226,7 @@ class KVCacheManager:
         self._kv_cache: Tuple[Tensor, Tensor] = None
         self._seq_mask: Tensor = None
         self._initialize()
 
     def _initialize(self):
         k_cache = torch.zeros(
             (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
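
In practice the new context manager wraps checkpoint loading so freshly constructed modules skip their random weight initialization; a short sketch of the intended usage (the "params" path is illustrative):

# Hypothetical path; the nn.init functions are restored when the block exits.
with disable_random_init():
    param = ModelParameter.load("params")
param.to(device='cuda', dtype=torch.bfloat16)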

View File

@@ -9,33 +9,37 @@ from khaosz.config.param_config import ModelParameter
 HistoryType = List[Tuple[str, str]]
 
 
 def build_prompt(
     query: str,
-    init_prompt: Optional[str] = None,
-    history: Optional[List[Tuple[str, str]]] = None
+    system_prompt: Optional[str] = None,
+    history: Optional[HistoryType] = None
 ) -> str:
-    """
-    Build prompt in ChatML format for query and history
-
-    Args:
-        query(str): query string
-        history(Optional[List[Tuple[str, str]]]): history list of query and response
-
-    Returns:
-        str: prompt string in ChatML format
-    """
-    prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
+    """
+    Build prompt in ChatML format for query and history.
+
+    Args:
+        query (str): query string.
+        system_prompt (Optional[str]): system prompt string.
+        history (Optional[HistoryType]): history list of query and response.
+
+    Returns:
+        str: prompt string in ChatML format.
+    """
+    result = ""
+    if system_prompt:
+        result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
 
     # (convert tuple format to ChatML)
     if history:
         for user_msg, assistant_msg in history:
-            prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
-            prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
+            result += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+            result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
 
-    prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
-    prompt += "<|im_start|>assistant\n"
-    return prompt
+    result += f"<|im_start|>user\n{query}<|im_end|>\n"
+    result += "<|im_start|>assistant\n"
+    return result
 
 
 def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
     """
@@ -59,8 +63,21 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
     return new_ids_list, max_ids_len
 
 
 @dataclass
 class GenerationRequest:
+    """
+    Request parameters for text generation.
+
+    Attributes:
+        top_k: Top-k sampling parameter.
+        top_p: Top-p (nucleus) sampling parameter.
+        temperature: Sampling temperature.
+        max_len: Maximum generation length.
+        query: Input query (string or list of strings for batch).
+        history: Conversation history.
+        system_prompt: System prompt for the conversation.
+    """
     top_k: int
     top_p: float
     temperature: float
@@ -70,8 +87,6 @@ class GenerationRequest:
     history: Optional[Union[HistoryType, List[HistoryType]]] = None
     system_prompt: Optional[str] = None
-    build_prompt: bool = True
 
     def __post_init__(self):
         if not isinstance(self.top_k, int) or self.top_k < 0:
             raise ValueError("top_k must be a non-negative integer")
@@ -89,19 +104,21 @@ class LoopGenerator(GeneratorCore):
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
 
-        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
-        ids = self.tokenizer.encode(input_args)
+        prompt = build_prompt(request.query, request.history)
+        ids = self.tokenizer.encode(prompt)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
         start_cache_pos = len(ids)
-        cur_cache_pos = 0
 
         self.model.eval()
         kv_caches = cache_manager.get_kvcache()
         ids = self.generate_loop(
-            input_ids, ids, request.temperature, request.top_k, request.top_p,
+            input_ids,
+            ids,
+            request.temperature,
+            request.top_k,
+            request.top_p,
             kv_caches=kv_caches,
-            start_pos=cur_cache_pos
         )
 
         response = self.tokenizer.decode(ids[start_cache_pos:])
@@ -112,16 +129,12 @@ class StreamGenerator(GeneratorCore):
     def __init__(self, parameter: ModelParameter):
         super().__init__(parameter)
 
-    def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
-        if request.history is None:
-            request.history = []
+    def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
         device = next(self.model.parameters()).device
         cache_manager = KVCacheManager(self.config, 1, device=device)
 
-        input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
-        ids = self.tokenizer.encode(input_args)
+        prompt = build_prompt(request.query, request.history)
+        ids = self.tokenizer.encode(prompt)
         input_ids = torch.tensor([ids], device=device, dtype=torch.long)
         start_cache_pos = len(ids)
@@ -141,10 +154,10 @@ class StreamGenerator(GeneratorCore):
             cur_cache_pos += cache_increase
 
             response = self.tokenizer.decode(ids[start_cache_pos:])
-            yield response, request.history + [(request.query, response)]
+            yield response
 
             if next_token_id.item() in self.tokenizer.stop_ids:
-                yield response + "\n", request.history + [(request.query, response)]
+                yield response + "\n"
                 break
@@ -217,4 +230,36 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
     def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
         return super().encode(sentence)
 
+
+class GeneratorFactory:
+    """Factory class for creating appropriate generator instances based on request features."""
+
+    @staticmethod
+    def create_generator(parameter: ModelParameter, request: GenerationRequest):
+        """
+        Create a generator based on the characteristics of GenerationRequest.
+
+        Args:
+            parameter: Model parameters
+            request: Generation request
+
+        Returns:
+            Subclass instance of GeneratorCore
+        """
+        # Streaming generation detection: check stream field
+        if request.stream:
+            return StreamGenerator(parameter)
+
+        # Batch generation detection: query is a list
+        if isinstance(request.query, list):
+            return BatchGenerator(parameter)
+
+        # Default return LoopGenerator
+        return LoopGenerator(parameter)
+
+    @staticmethod
+    def create_encoder(parameter: ModelParameter):
+        """Create an EmbeddingEncoder instance"""
+        return EmbeddingEncoder(parameter)
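
A brief usage sketch of the factory-based dispatch. It assumes GenerationRequest also carries the stream flag that create_generator reads (not shown in the hunk above, presumably defaulting to False), and the max_len value is illustrative:

# Hypothetical usage; `param` is a loaded ModelParameter.
request = GenerationRequest(
    top_k=50, top_p=0.95, temperature=0.8, max_len=1024,
    query=["你好", "请问什么是人工智能"],  # a list query selects BatchGenerator
)
generator = GeneratorFactory.create_generator(param, request)
responses = generator.generate(request)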

View File

@@ -1 +0,0 @@
# init file

View File

@@ -1,88 +0,0 @@
import torch
import sqlite3
import numpy as np
from torch import Tensor
from typing import Dict, List, Tuple
class Retriever:
def __init__(self, db_path=None):
self.data: Dict[str, Tensor] = {}
self.embedding_cache: Tensor = None
self.is_caculated: bool = False
if db_path is not None:
self.load(db_path)
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
if not self.data:
return []
query = query.flatten().unsqueeze(1) # [dim, 1]
norm_embeddings = self._embeddings.to(
device=query.device,
dtype=query.dtype
) # [n_vectors, dim]
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
top_k = min(top_k, len(self.data))
indices = sim_scores.topk(top_k).indices
keys = list(self.data.keys())
return [(keys[i], sim_scores[i].item()) for i in indices]
def add_vector(self, key: str, vector_data: Tensor):
self.is_caculated = False
self.data[key] = vector_data.flatten().float().cpu()
def delete_vector(self, key: str):
self.is_caculated = False
self.data.pop(key, None)
def save(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('DELETE FROM vectors')
for item, vec in self.data.items():
vec_bytes = vec.numpy().tobytes()
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
(item, vec_bytes))
conn.commit()
conn.close()
def load(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('SELECT key, vector FROM vectors')
rows = cursor.fetchall()
self.data = {}
for row in rows:
key, vec_bytes = row
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
vec = torch.from_numpy(vec_numpy)
self.data[key] = vec
conn.close()
def _init_db(self,cursor: sqlite3.Cursor):
# Create table if not exists (in case loading from a new database)
cursor.execute('''
CREATE TABLE IF NOT EXISTS vectors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT UNIQUE NOT NULL,
vector BLOB NOT NULL
)''')
@property
def _embeddings(self) -> Tensor:
if not self.is_caculated:
embeddings = torch.stack(list(self.data.values()))
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
self.embedding_cache = norm_embeddings
return self.embedding_cache

View File

@@ -1,127 +0,0 @@
import re
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from torch import Tensor
from typing import List, Callable, Optional
class BaseTextSplitter(ABC):
def __init__(
self,
max_len: int = 512,
chunk_overlap: int = 0,
):
if max_len <= 0:
raise ValueError("max_len must be > 0")
if chunk_overlap < 0:
raise ValueError("chunk_overlap must be >= 0")
self.max_len = max_len
self.chunk_overlap = chunk_overlap
@abstractmethod
def split(self, text: str, **kwargs) -> List[str]:
raise NotImplementedError
def preprocess(self, text: str) -> str:
return text.strip()
def postprocess(self, chunks: List[str]) -> List[str]:
return [chunk.strip() for chunk in chunks if chunk.strip()]
class PriorityTextSplitter(BaseTextSplitter):
def __init__(
self,
separators: List[str],
max_len: int = 512,
chunk_overlap: int = 0,
):
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
if not separators:
raise ValueError("separators must be a non-empty list")
self.separators = separators
def split(self, text: str) -> List[str]:
text = self.preprocess(text)
for sep in self.separators:
parts = text.split(sep)
valid_parts = [p.strip() for p in parts if p.strip()]
if len(valid_parts) > 1:
return self.postprocess(valid_parts)
return [text]
class SemanticTextSplitter(BaseTextSplitter):
DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'
def __init__(
self,
embedding_func: Callable[[List[str]], List[Tensor]],
pattern: Optional[str] = None,
max_len: int = 512,
chunk_overlap: int = 0,
):
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
if not callable(embedding_func):
raise TypeError("embedding_func must be callable")
self.embedding_func = embedding_func
self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN
def split(
self,
text: str,
threshold: float = 0.5,
window_size: int = 1,
) -> List[str]:
text = self.preprocess(text)
sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]
if len(sentences) <= 1:
return self.postprocess(sentences)
try:
sentence_embs = self.embedding_func(sentences)
except Exception as e:
raise RuntimeError(f"Embedding generation failed: {e}")
if len(sentence_embs) != len(sentences):
raise ValueError("Embedding function must return one vector per sentence")
chunks = []
emb_tensor = torch.stack(sentence_embs) # shape: [N, D]
current_chunk: List[str] = [sentences[0]]
for i in range(1, len(sentences)):
start_prev = max(0, i - window_size)
end_prev = i
start_next = i
end_next = min(len(sentences), i + window_size)
prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0)
next_window_emb = emb_tensor[start_next:end_next].mean(dim=0)
similarity = F.cosine_similarity(
prev_window_emb.unsqueeze(0),
next_window_emb.unsqueeze(0),
dim=1
).item()
dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)
if similarity < dynamic_threshold:
chunks.append(" ".join(current_chunk))
overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
current_chunk = current_chunk[overlap_start:]
current_chunk.append(sentences[i])
else:
current_chunk.append(sentences[i])
if current_chunk:
chunks.append(" ".join(current_chunk))
return self.postprocess(chunks)

View File

@@ -54,7 +54,7 @@ def test_env(request: pytest.FixtureRequest):
 def test_model_parameter(test_env):
     save_dir = os.path.join(test_env["test_dir"], "save")
     model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
-    model_param.save(save_dir)
+    ModelParameter.save(model_param, save_dir)
     assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
     assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))

View File

@@ -1,45 +1,10 @@
 import torch
 import json
-import torch
 import argparse
-from khaosz import Khaosz
-from typing import List
-from tqdm import tqdm
+from khaosz.config.param_config import ModelParameter
+from khaosz.inference.generator import BatchGenerator, GenerationRequest
+from khaosz.inference.core import disable_random_init
 
 
-def batch_generate(
-    model: Khaosz,
-    query: List[str],
-    temperature: float,
-    top_k: int,
-    top_p: float,
-    batch_size: int,
-) -> List:
-    assert batch_size > 0
-    sorted_query = sorted(query, key=lambda x: len(x), reverse=True)
-    original_indices = {query: idx for idx, query in enumerate(query)}
-    responses = [None] * len(query)
-    total_batches = (len(sorted_query) + batch_size - 1) // batch_size
-
-    for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
-        batch_query = sorted_query[i: min(i + batch_size, len(query))]
-        if not isinstance(batch_query, list):
-            batch_query = [batch_query]
-        batch_responses = model.batch_generate(
-            query=batch_query,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p
-        )
-        for query, response in zip(batch_query, batch_responses):
-            original_idx = original_indices[query]
-            responses[original_idx] = response
-
-    return responses
-
-
 def processor(
@@ -53,24 +18,31 @@ def processor(
     question_key: str,
     response_key: str,
 ):
-    model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
+    with disable_random_init():
+        param = ModelParameter.load(model_dir)
+        param.to(device='cuda', dtype=torch.bfloat16)
+    generator = BatchGenerator(param)
 
     with open(input_json_file, "r", encoding='utf-8') as f:
         input_data = [json.loads(line) for line in f]
-    query = [item[question_key] for item in input_data]
+    queries = [item[question_key] for item in input_data]
 
-    responses = batch_generate(
-        model=model,
-        query=query,
+    request = GenerationRequest(
+        query=queries,
         temperature=temperature,
-        top_k=top_k,
         top_p=top_p,
-        batch_size=batch_size
+        top_k=top_k,
+        max_len=param.config.max_len,
+        history=None,
+        system_prompt=None,
     )
+    responses = generator.generate(request)
 
-    # Write output in JSONL format
     with open(output_json_file, "w", encoding='utf-8') as f:
-        for query, response in zip(query, responses):
+        for query, response in zip(queries, responses):
             output_item = {question_key: query, response_key: response}
             f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
@@ -89,4 +61,6 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
 
     args = parser.parse_args()
-    processor(**vars(args))
+
+    with torch.inference_mode():
+        processor(**vars(args))

View File

@@ -6,7 +6,9 @@ import argparse
 import tqdm
 
 from torch import Tensor
-from khaosz import Khaosz
+from khaosz.config.param_config import ModelParameter
+from khaosz.inference.core import disable_random_init
 
 
 def compute_perplexity(
     model: nn.Module,
@@ -45,22 +47,23 @@ def process_file(
     batch_size: int,
     text_key: str
 ):
-    model = Khaosz(model_dir).to(device="cuda", dtype=torch.bfloat16)
-    tokenizer = model.parameter.tokenizer
+    with disable_random_init():
+        param = ModelParameter.load(model_dir)
+        param.to(device='cuda', dtype=torch.bfloat16)
+    model = param.model
+    tokenizer = param.tokenizer
 
     with open(input_file, "r", encoding='utf-8') as f:
         input_data = [json.loads(line) for line in f]
 
     texts = [item[text_key] for item in input_data]
     encoded_texts = [tokenizer.encode(text) for text in texts]
     output_data = []
 
     for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
         batch_encoded = encoded_texts[i:i + batch_size]
         batch_texts = texts[i:i + batch_size]
 
-        # Pad sequences to the same length (left padding)
         max_len = max(len(seq) for seq in batch_encoded)
         padded_ids = []
         masks = []
@@ -74,10 +77,7 @@ def process_file(
         input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
         input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
 
-        # Compute perplexity
-        with torch.inference_mode():
-            perplexity = compute_perplexity(model.parameter.model, input_ids, input_mask)
+        perplexity = compute_perplexity(model, input_ids, input_mask)
 
         for text, ppl in zip(batch_texts, perplexity):
             output_data.append({text_key: text, "ppl": float(ppl.item())})
@@ -87,16 +87,14 @@ def process_file(
             f.write(json.dumps(item, ensure_ascii=False) + '\n')
 
 
-def main():
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
     parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
     parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
     parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
     parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
     parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
 
     args = parser.parse_args()
-    process_file(**vars(args))
 
-if __name__ == "__main__":
-    main()
+    with torch.inference_mode():
+        process_file(**vars(args))

View File

@@ -16,7 +16,7 @@ def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Train the Transformer model.")
 
-    parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
+    parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
     parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
     parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
@@ -67,18 +67,14 @@ def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler
     return SchedulerFactory.load(optimizer, **kwargs)
 
 
 def prepare_checkpoint(model: nn.Module) -> dict:
-    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-        state_dict = model.module.state_dict()
-    else:
-        state_dict = model.state_dict()
-    return state_dict
+    return model.module.state_dict()
 
 
 def train(
     train_type: str,
     param_path: str,
     data_root_path: str,
-    max_lr: int,
+    max_lr: float,
     n_epoch: int,
     batch_size: int,
     start_epoch: int,
@@ -104,8 +100,7 @@ def train(
     assert train_type in ["seq", "sft", "dpo"]
     assert os.path.exists(param_path)
 
-    parameter = ModelParameter()
-    parameter.load(param_path)
+    parameter = ModelParameter.load(param_path)
 
     if window_size is None:
         window_size = parameter.config.max_len