feat: 增加推理部分工厂模式
This commit is contained in:
parent
980299cd54
commit
c01791ff54
|
|
@ -1,27 +1,35 @@
|
|||
import os
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
from khaosz.inference.core import disable_random_init
|
||||
from khaosz.inference.generator import LoopGenerator, GenerationRequest
|
||||
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
def generate_text():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
|
||||
with disable_random_init():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
query = input(">> ")
|
||||
|
||||
response = model.text_generate(
|
||||
query=query,
|
||||
request = GenerationRequest(
|
||||
query=query,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
top_k=50,
|
||||
max_len=param.config.max_len,
|
||||
history=None,
|
||||
system_prompt=None,
|
||||
)
|
||||
generator = LoopGenerator(param)
|
||||
response = generator.generate(request)
|
||||
|
||||
print(response)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_text()
|
||||
|
|
@ -1,22 +1,31 @@
|
|||
import os
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
from khaosz.inference.core import disable_random_init
|
||||
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
def batch_generate():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
with disable_random_init():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
generator = BatchGenerator(param)
|
||||
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
|
||||
|
||||
responses = model.batch_generate(
|
||||
request = GenerationRequest(
|
||||
query=inputs,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
top_k=50,
|
||||
max_len=param.config.max_len,
|
||||
history=None,
|
||||
system_prompt=None,
|
||||
)
|
||||
responses = generator.generate(request)
|
||||
|
||||
for q, r in zip(inputs, responses):
|
||||
print((q, r))
|
||||
|
|
|
|||
|
|
@ -1,43 +0,0 @@
|
|||
import os
|
||||
import torch
|
||||
from khaosz import Khaosz, SemanticTextSplitter, Retriever
|
||||
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
context_path = os.path.join(PROJECT_ROOT, "README.md")
|
||||
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
spliter = SemanticTextSplitter(model.encode)
|
||||
retriever = Retriever()
|
||||
text = open(context_path, "r", encoding="utf-8").read()
|
||||
|
||||
res = spliter.split(text, threshold=0.8, window_size=1)
|
||||
# print(("\n" + "+"*100 + "\n").join(res))
|
||||
|
||||
res_embs = model.encode(res)
|
||||
for sentence, emb in zip(res, res_embs):
|
||||
retriever.add_vector(sentence, emb)
|
||||
|
||||
retrive_top_k = 5
|
||||
query = "作者设计了一个怎样的模型"
|
||||
emb_query = model.encode(query)
|
||||
retrieved = retriever.retrieve(emb_query, retrive_top_k)
|
||||
retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
|
||||
|
||||
retrive_response = model.retrieve_generate(
|
||||
retrieved=retrieved_content,
|
||||
query=query,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
)
|
||||
|
||||
print("retrieve content:")
|
||||
print(retrieved_content)
|
||||
|
||||
print("\n\nretrive generate:")
|
||||
print(retrive_response)
|
||||
|
|
@ -1,14 +1,21 @@
|
|||
import os
|
||||
import torch
|
||||
from khaosz import Khaosz
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
from khaosz.inference.core import disable_random_init
|
||||
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
||||
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
def chat():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
with disable_random_init():
|
||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
generator = StreamGenerator(param)
|
||||
|
||||
history = []
|
||||
while True:
|
||||
|
|
@ -16,17 +23,27 @@ def chat():
|
|||
if query == "!exit":
|
||||
break
|
||||
|
||||
response_size = 0
|
||||
for response, history in model.stream_generate(
|
||||
query=query,
|
||||
history=history,
|
||||
request = GenerationRequest(
|
||||
query=query,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50
|
||||
):
|
||||
top_k=50,
|
||||
max_len=param.config.max_len,
|
||||
history=history,
|
||||
system_prompt=None,
|
||||
)
|
||||
|
||||
response_size = 0
|
||||
full_response = ""
|
||||
for response in generator.generate(request):
|
||||
# response is the cumulative response up to current token
|
||||
print(response[response_size:], end="", flush=True)
|
||||
response_size = len(response)
|
||||
|
||||
full_response = response
|
||||
|
||||
# After generation, update history
|
||||
history.append((query, full_response.strip()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
chat()
|
||||
|
|
@ -1,17 +1,11 @@
|
|||
__version__ = "1.3.2"
|
||||
__author__ = "ViperEkura"
|
||||
|
||||
from khaosz.api import Khaosz
|
||||
from khaosz.config import (
|
||||
ModelConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from khaosz.model.transformer import Transformer
|
||||
from khaosz.utils.retriever import Retriever
|
||||
from khaosz.utils.splitter import (
|
||||
SemanticTextSplitter,
|
||||
PriorityTextSplitter
|
||||
)
|
||||
from khaosz.data import (
|
||||
DatasetLoader,
|
||||
BpeTokenizer
|
||||
|
|
@ -22,8 +16,8 @@ from khaosz.inference.generator import (
|
|||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
EmbeddingEncoder,
|
||||
GeneratorFactory
|
||||
)
|
||||
|
||||
from khaosz.trainer import (
|
||||
Trainer,
|
||||
StrategyFactory,
|
||||
|
|
@ -31,14 +25,8 @@ from khaosz.trainer import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
"Khaosz",
|
||||
|
||||
"Transformer",
|
||||
|
||||
"Retriever",
|
||||
"SemanticTextSplitter",
|
||||
"PriorityTextSplitter",
|
||||
|
||||
"ModelConfig",
|
||||
"TrainConfig",
|
||||
|
||||
|
|
@ -50,6 +38,7 @@ __all__ = [
|
|||
"StreamGenerator",
|
||||
"BatchGenerator",
|
||||
"EmbeddingEncoder",
|
||||
"GeneratorFactory",
|
||||
|
||||
"Trainer",
|
||||
"StrategyFactory",
|
||||
|
|
|
|||
138
khaosz/api.py
138
khaosz/api.py
|
|
@ -1,138 +0,0 @@
|
|||
from torch import nn
|
||||
from torch import Tensor
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Tuple, Generator, Union
|
||||
|
||||
from khaosz.inference.generator import (
|
||||
GenerationRequest,
|
||||
LoopGenerator,
|
||||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
EmbeddingEncoder
|
||||
)
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
|
||||
@contextmanager
|
||||
def disable_random_init():
|
||||
init_functions = [
|
||||
'xavier_normal_', 'xavier_uniform_',
|
||||
'kaiming_normal_', 'kaiming_uniform_',
|
||||
'zeros_', 'ones_', 'constant_',
|
||||
'normal_', 'uniform_'
|
||||
]
|
||||
original_funcs = {}
|
||||
for name in init_functions:
|
||||
if hasattr(nn.init, name):
|
||||
original_funcs[name] = getattr(nn.init, name)
|
||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for name, orig_func in original_funcs.items():
|
||||
setattr(nn.init, name, orig_func)
|
||||
|
||||
|
||||
class Khaosz:
|
||||
def __init__(self, model_dir: str):
|
||||
with disable_random_init():
|
||||
self.parameter = ModelParameter()
|
||||
self.parameter.load(model_dir)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
self.parameter.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def generate(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]]=None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> str:
|
||||
generator = LoopGenerator(self.parameter)
|
||||
return generator.generate(
|
||||
GenerationRequest(
|
||||
top_k, top_p, temperature,
|
||||
self.parameter.config.max_len,
|
||||
query=query,
|
||||
history=history,
|
||||
build_prompt=True
|
||||
))
|
||||
|
||||
def batch_generate(
|
||||
self,
|
||||
query: List[str],
|
||||
history: List[Tuple[str, str]]=None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> List[str]:
|
||||
generator = BatchGenerator(self.parameter)
|
||||
return generator.generate(
|
||||
GenerationRequest(
|
||||
top_k, top_p, temperature,
|
||||
self.parameter.config.max_len,
|
||||
query=query,
|
||||
history=history,
|
||||
build_prompt=True
|
||||
))
|
||||
|
||||
def stream_generate(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]]=None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
|
||||
stream_generator = StreamGenerator(self.parameter)
|
||||
return stream_generator.generate(
|
||||
GenerationRequest(
|
||||
top_k, top_p, temperature,
|
||||
self.parameter.config.max_len,
|
||||
query=query,
|
||||
history=history,
|
||||
build_prompt=True
|
||||
))
|
||||
|
||||
def retrieve_generate(
|
||||
self,
|
||||
retrieved,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = None,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> str:
|
||||
generator = LoopGenerator(self.parameter)
|
||||
return generator.generate(
|
||||
GenerationRequest(
|
||||
top_k, top_p, temperature,
|
||||
self.parameter.config.max_len,
|
||||
query=query,
|
||||
history=history,
|
||||
system_prompt=retrieved,
|
||||
build_prompt=True
|
||||
))
|
||||
|
||||
def text_generate(
|
||||
self,
|
||||
query: str,
|
||||
temperature: float=0.8,
|
||||
top_k: int=50,
|
||||
top_p: float=0.95,
|
||||
) -> str:
|
||||
generator = LoopGenerator(self.parameter)
|
||||
return generator.generate(
|
||||
GenerationRequest(
|
||||
top_k, top_p, temperature,
|
||||
self.parameter.config.max_len,
|
||||
query=query,
|
||||
build_prompt=False
|
||||
))
|
||||
|
||||
|
||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||
encoder = EmbeddingEncoder(self.parameter)
|
||||
return encoder.encode(sentence)
|
||||
|
|
@ -72,9 +72,12 @@ class BaseModelIO:
|
|||
class ModelParameter(BaseModelIO):
|
||||
"""Container for model parameters with serialization capabilities."""
|
||||
|
||||
def save(self, save_dir: Union[str, Path]):
|
||||
self.save_components(save_dir)
|
||||
@classmethod
|
||||
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
|
||||
instance.save_components(save_dir)
|
||||
|
||||
def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||
return self.load_components(load_dir)
|
||||
@classmethod
|
||||
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||
instance = cls()
|
||||
return instance.load_components(load_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from khaosz.inference.core import (
|
||||
disable_random_init,
|
||||
GeneratorCore,
|
||||
EmbeddingEncoderCore,
|
||||
KVCacheManager,
|
||||
|
|
@ -10,9 +11,11 @@ from khaosz.inference.generator import (
|
|||
StreamGenerator,
|
||||
BatchGenerator,
|
||||
EmbeddingEncoder,
|
||||
GeneratorFactory
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"disable_random_init",
|
||||
"GeneratorCore",
|
||||
"EmbeddingEncoderCore",
|
||||
"KVCacheManager",
|
||||
|
|
@ -22,4 +25,5 @@ __all__ = [
|
|||
"StreamGenerator",
|
||||
"BatchGenerator",
|
||||
"EmbeddingEncoder",
|
||||
"GeneratorFactory"
|
||||
]
|
||||
|
|
@ -1,5 +1,8 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch import Tensor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||
from khaosz.config import ModelParameter, ModelConfig
|
||||
|
||||
|
|
@ -54,6 +57,26 @@ def apply_sampling_strategies(
|
|||
return logits
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_random_init():
|
||||
init_functions = [
|
||||
'xavier_normal_', 'xavier_uniform_',
|
||||
'kaiming_normal_', 'kaiming_uniform_',
|
||||
'zeros_', 'ones_', 'constant_',
|
||||
'normal_', 'uniform_'
|
||||
]
|
||||
original_funcs = {}
|
||||
for name in init_functions:
|
||||
if hasattr(nn.init, name):
|
||||
original_funcs[name] = getattr(nn.init, name)
|
||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for name, orig_func in original_funcs.items():
|
||||
setattr(nn.init, name, orig_func)
|
||||
|
||||
|
||||
class GeneratorCore:
|
||||
def __init__(self, parameter: ModelParameter):
|
||||
self.model = parameter.model
|
||||
|
|
@ -82,10 +105,6 @@ class GeneratorCore:
|
|||
|
||||
return next_token_id, cache_increase
|
||||
|
||||
def to(self, *args, **kargs) -> Self:
|
||||
self.model.to(*args, **kargs)
|
||||
return self
|
||||
|
||||
def generate_loop(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
|
|
@ -115,6 +134,10 @@ class GeneratorCore:
|
|||
break
|
||||
|
||||
return ids
|
||||
|
||||
def to(self, *args, **kargs) -> Self:
|
||||
self.model.to(*args, **kargs)
|
||||
return self
|
||||
|
||||
|
||||
class EmbeddingEncoderCore:
|
||||
|
|
@ -203,7 +226,7 @@ class KVCacheManager:
|
|||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||
self._seq_mask: Tensor = None
|
||||
self._initialize()
|
||||
|
||||
|
||||
def _initialize(self):
|
||||
k_cache = torch.zeros(
|
||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||
|
|
|
|||
|
|
@ -9,33 +9,37 @@ from khaosz.config.param_config import ModelParameter
|
|||
HistoryType = List[Tuple[str, str]]
|
||||
|
||||
def build_prompt(
|
||||
query: str,
|
||||
init_prompt: Optional[str] = None,
|
||||
history: Optional[List[Tuple[str, str]]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build prompt in ChatML format for query and history
|
||||
|
||||
Args:
|
||||
query(str): query string
|
||||
history(Optional[List[Tuple[str, str]]]): history list of query and response
|
||||
|
||||
Returns:
|
||||
str: prompt string in ChatML format
|
||||
|
||||
query: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
history: Optional[HistoryType] = None
|
||||
) -> str:
|
||||
"""
|
||||
prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
|
||||
|
||||
Build prompt in ChatML format for query and history.
|
||||
|
||||
Args:
|
||||
query (str): query string.
|
||||
system_prompt (Optional[str]): system prompt string.
|
||||
history (Optional[HistoryType]): history list of query and response.
|
||||
|
||||
Returns:
|
||||
str: prompt string in ChatML format.
|
||||
"""
|
||||
result = ""
|
||||
|
||||
if system_prompt:
|
||||
result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
|
||||
|
||||
# (convert tuple format to ChatML)
|
||||
if history:
|
||||
for user_msg, assistant_msg in history:
|
||||
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
|
||||
prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
|
||||
|
||||
prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
|
||||
prompt += "<|im_start|>assistant\n"
|
||||
|
||||
return prompt
|
||||
result += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
|
||||
result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
|
||||
|
||||
result += f"<|im_start|>user\n{query}<|im_end|>\n"
|
||||
result += "<|im_start|>assistant\n"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||
"""
|
||||
|
|
@ -59,8 +63,21 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
|
|||
|
||||
return new_ids_list, max_ids_len
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationRequest:
|
||||
"""
|
||||
Request parameters for text generation.
|
||||
|
||||
Attributes:
|
||||
top_k: Top-k sampling parameter.
|
||||
top_p: Top-p (nucleus) sampling parameter.
|
||||
temperature: Sampling temperature.
|
||||
max_len: Maximum generation length.
|
||||
query: Input query (string or list of strings for batch).
|
||||
history: Conversation history.
|
||||
system_prompt: System prompt for the conversation.
|
||||
"""
|
||||
top_k: int
|
||||
top_p: float
|
||||
temperature: float
|
||||
|
|
@ -70,8 +87,6 @@ class GenerationRequest:
|
|||
history: Optional[Union[HistoryType, List[HistoryType]]] = None
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
build_prompt: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.top_k, int) or self.top_k < 0:
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
|
|
@ -89,19 +104,21 @@ class LoopGenerator(GeneratorCore):
|
|||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
|
||||
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
|
||||
ids = self.tokenizer.encode(input_args)
|
||||
prompt = build_prompt(request.query, request.history)
|
||||
ids = self.tokenizer.encode(prompt)
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
cur_cache_pos = 0
|
||||
self.model.eval()
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
|
||||
ids = self.generate_loop(
|
||||
input_ids, ids, request.temperature, request.top_k, request.top_p,
|
||||
input_ids,
|
||||
ids,
|
||||
request.temperature,
|
||||
request.top_k,
|
||||
request.top_p,
|
||||
kv_caches=kv_caches,
|
||||
start_pos=cur_cache_pos
|
||||
)
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
|
||||
|
|
@ -112,16 +129,12 @@ class StreamGenerator(GeneratorCore):
|
|||
def __init__(self, parameter: ModelParameter):
|
||||
super().__init__(parameter)
|
||||
|
||||
def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
|
||||
|
||||
if request.history is None:
|
||||
request.history = []
|
||||
|
||||
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
||||
device = next(self.model.parameters()).device
|
||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||
|
||||
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
|
||||
ids = self.tokenizer.encode(input_args)
|
||||
prompt = build_prompt(request.query, request.history)
|
||||
ids = self.tokenizer.encode(prompt)
|
||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||
|
||||
start_cache_pos = len(ids)
|
||||
|
|
@ -141,10 +154,10 @@ class StreamGenerator(GeneratorCore):
|
|||
cur_cache_pos += cache_increase
|
||||
|
||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||
yield response, request.history + [(request.query, response)]
|
||||
yield response
|
||||
|
||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||
yield response + "\n", request.history + [(request.query, response)]
|
||||
yield response + "\n"
|
||||
break
|
||||
|
||||
|
||||
|
|
@ -217,4 +230,36 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
|
|||
|
||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||
return super().encode(sentence)
|
||||
|
||||
|
||||
class GeneratorFactory:
|
||||
"""Factory class for creating appropriate generator instances based on request features."""
|
||||
|
||||
@staticmethod
|
||||
def create_generator(parameter: ModelParameter, request: GenerationRequest):
|
||||
"""
|
||||
Create a generator based on the characteristics of GenerationRequest.
|
||||
Args:
|
||||
parameter: Model parameters
|
||||
request: Generation request
|
||||
|
||||
Returns:
|
||||
Subclass instance of GeneratorCore
|
||||
"""
|
||||
|
||||
# Streaming generation detection: check stream field
|
||||
if request.stream:
|
||||
return StreamGenerator(parameter)
|
||||
|
||||
# Batch generation detection: query is a list
|
||||
if isinstance(request.query, list):
|
||||
return BatchGenerator(parameter)
|
||||
|
||||
# Default return LoopGenerator
|
||||
return LoopGenerator(parameter)
|
||||
|
||||
@staticmethod
|
||||
def create_encoder(parameter: ModelParameter):
|
||||
"""Create an EmbeddingEncoder instance"""
|
||||
return EmbeddingEncoder(parameter)
|
||||
|
||||
|
|
@ -1 +0,0 @@
|
|||
# init file
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
import torch
|
||||
import sqlite3
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class Retriever:
|
||||
def __init__(self, db_path=None):
|
||||
self.data: Dict[str, Tensor] = {}
|
||||
self.embedding_cache: Tensor = None
|
||||
self.is_caculated: bool = False
|
||||
|
||||
if db_path is not None:
|
||||
self.load(db_path)
|
||||
|
||||
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
|
||||
if not self.data:
|
||||
return []
|
||||
|
||||
query = query.flatten().unsqueeze(1) # [dim, 1]
|
||||
norm_embeddings = self._embeddings.to(
|
||||
device=query.device,
|
||||
dtype=query.dtype
|
||||
) # [n_vectors, dim]
|
||||
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
|
||||
|
||||
top_k = min(top_k, len(self.data))
|
||||
indices = sim_scores.topk(top_k).indices
|
||||
keys = list(self.data.keys())
|
||||
|
||||
return [(keys[i], sim_scores[i].item()) for i in indices]
|
||||
|
||||
def add_vector(self, key: str, vector_data: Tensor):
|
||||
self.is_caculated = False
|
||||
self.data[key] = vector_data.flatten().float().cpu()
|
||||
|
||||
def delete_vector(self, key: str):
|
||||
self.is_caculated = False
|
||||
self.data.pop(key, None)
|
||||
|
||||
def save(self, db_path):
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
self._init_db(cursor)
|
||||
cursor.execute('DELETE FROM vectors')
|
||||
|
||||
for item, vec in self.data.items():
|
||||
vec_bytes = vec.numpy().tobytes()
|
||||
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
|
||||
(item, vec_bytes))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def load(self, db_path):
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
self._init_db(cursor)
|
||||
cursor.execute('SELECT key, vector FROM vectors')
|
||||
rows = cursor.fetchall()
|
||||
self.data = {}
|
||||
|
||||
for row in rows:
|
||||
key, vec_bytes = row
|
||||
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
|
||||
vec = torch.from_numpy(vec_numpy)
|
||||
self.data[key] = vec
|
||||
|
||||
conn.close()
|
||||
|
||||
def _init_db(self,cursor: sqlite3.Cursor):
|
||||
# Create table if not exists (in case loading from a new database)
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS vectors (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
key TEXT UNIQUE NOT NULL,
|
||||
vector BLOB NOT NULL
|
||||
)''')
|
||||
|
||||
@property
|
||||
def _embeddings(self) -> Tensor:
|
||||
if not self.is_caculated:
|
||||
embeddings = torch.stack(list(self.data.values()))
|
||||
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
|
||||
self.embedding_cache = norm_embeddings
|
||||
|
||||
return self.embedding_cache
|
||||
|
|
@ -1,127 +0,0 @@
|
|||
import re
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import Tensor
|
||||
from typing import List, Callable, Optional
|
||||
|
||||
|
||||
class BaseTextSplitter(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
max_len: int = 512,
|
||||
chunk_overlap: int = 0,
|
||||
):
|
||||
if max_len <= 0:
|
||||
raise ValueError("max_len must be > 0")
|
||||
if chunk_overlap < 0:
|
||||
raise ValueError("chunk_overlap must be >= 0")
|
||||
|
||||
self.max_len = max_len
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
@abstractmethod
|
||||
def split(self, text: str, **kwargs) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def preprocess(self, text: str) -> str:
|
||||
return text.strip()
|
||||
|
||||
def postprocess(self, chunks: List[str]) -> List[str]:
|
||||
return [chunk.strip() for chunk in chunks if chunk.strip()]
|
||||
|
||||
|
||||
class PriorityTextSplitter(BaseTextSplitter):
|
||||
def __init__(
|
||||
self,
|
||||
separators: List[str],
|
||||
max_len: int = 512,
|
||||
chunk_overlap: int = 0,
|
||||
):
|
||||
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
|
||||
if not separators:
|
||||
raise ValueError("separators must be a non-empty list")
|
||||
self.separators = separators
|
||||
|
||||
def split(self, text: str) -> List[str]:
|
||||
text = self.preprocess(text)
|
||||
for sep in self.separators:
|
||||
parts = text.split(sep)
|
||||
|
||||
valid_parts = [p.strip() for p in parts if p.strip()]
|
||||
if len(valid_parts) > 1:
|
||||
return self.postprocess(valid_parts)
|
||||
return [text]
|
||||
|
||||
|
||||
class SemanticTextSplitter(BaseTextSplitter):
|
||||
|
||||
DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_func: Callable[[List[str]], List[Tensor]],
|
||||
pattern: Optional[str] = None,
|
||||
max_len: int = 512,
|
||||
chunk_overlap: int = 0,
|
||||
):
|
||||
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
|
||||
if not callable(embedding_func):
|
||||
raise TypeError("embedding_func must be callable")
|
||||
self.embedding_func = embedding_func
|
||||
self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN
|
||||
|
||||
def split(
|
||||
self,
|
||||
text: str,
|
||||
threshold: float = 0.5,
|
||||
window_size: int = 1,
|
||||
) -> List[str]:
|
||||
text = self.preprocess(text)
|
||||
sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]
|
||||
|
||||
if len(sentences) <= 1:
|
||||
return self.postprocess(sentences)
|
||||
|
||||
try:
|
||||
sentence_embs = self.embedding_func(sentences)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Embedding generation failed: {e}")
|
||||
|
||||
if len(sentence_embs) != len(sentences):
|
||||
raise ValueError("Embedding function must return one vector per sentence")
|
||||
|
||||
chunks = []
|
||||
emb_tensor = torch.stack(sentence_embs) # shape: [N, D]
|
||||
current_chunk: List[str] = [sentences[0]]
|
||||
|
||||
for i in range(1, len(sentences)):
|
||||
start_prev = max(0, i - window_size)
|
||||
end_prev = i
|
||||
start_next = i
|
||||
end_next = min(len(sentences), i + window_size)
|
||||
|
||||
prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0)
|
||||
next_window_emb = emb_tensor[start_next:end_next].mean(dim=0)
|
||||
|
||||
similarity = F.cosine_similarity(
|
||||
prev_window_emb.unsqueeze(0),
|
||||
next_window_emb.unsqueeze(0),
|
||||
dim=1
|
||||
).item()
|
||||
|
||||
dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)
|
||||
|
||||
if similarity < dynamic_threshold:
|
||||
chunks.append(" ".join(current_chunk))
|
||||
overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
|
||||
current_chunk = current_chunk[overlap_start:]
|
||||
current_chunk.append(sentences[i])
|
||||
else:
|
||||
current_chunk.append(sentences[i])
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(" ".join(current_chunk))
|
||||
|
||||
return self.postprocess(chunks)
|
||||
|
|
@ -54,7 +54,7 @@ def test_env(request: pytest.FixtureRequest):
|
|||
def test_model_parameter(test_env):
|
||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
||||
model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"])
|
||||
model_param.save(save_dir)
|
||||
ModelParameter.save(model_param, save_dir)
|
||||
|
||||
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
|
||||
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
||||
|
|
|
|||
|
|
@ -1,45 +1,10 @@
|
|||
import torch
|
||||
import json
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
from khaosz import Khaosz
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def batch_generate(
|
||||
model: Khaosz,
|
||||
query: List[str],
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
batch_size: int,
|
||||
) -> List:
|
||||
assert batch_size > 0
|
||||
sorted_query = sorted(query, key=lambda x: len(x), reverse=True)
|
||||
original_indices = {query: idx for idx, query in enumerate(query)}
|
||||
|
||||
responses = [None] * len(query)
|
||||
total_batches = (len(sorted_query) + batch_size - 1) // batch_size
|
||||
|
||||
for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
|
||||
batch_query = sorted_query[i: min(i + batch_size, len(query))]
|
||||
if not isinstance(batch_query, list):
|
||||
batch_query = [batch_query]
|
||||
|
||||
batch_responses = model.batch_generate(
|
||||
query=batch_query,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p
|
||||
)
|
||||
|
||||
for query, response in zip(batch_query, batch_responses):
|
||||
original_idx = original_indices[query]
|
||||
responses[original_idx] = response
|
||||
|
||||
return responses
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
||||
from khaosz.inference.core import disable_random_init
|
||||
|
||||
|
||||
def processor(
|
||||
|
|
@ -53,24 +18,31 @@ def processor(
|
|||
question_key: str,
|
||||
response_key: str,
|
||||
):
|
||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
with disable_random_init():
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
generator = BatchGenerator(param)
|
||||
|
||||
with open(input_json_file, "r", encoding='utf-8') as f:
|
||||
input_data = [json.loads(line) for line in f]
|
||||
query = [item[question_key] for item in input_data]
|
||||
|
||||
responses = batch_generate(
|
||||
model=model,
|
||||
query=query,
|
||||
|
||||
queries = [item[question_key] for item in input_data]
|
||||
|
||||
request = GenerationRequest(
|
||||
query=queries,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
batch_size=batch_size
|
||||
top_k=top_k,
|
||||
max_len=param.config.max_len,
|
||||
history=None,
|
||||
system_prompt=None,
|
||||
)
|
||||
|
||||
# Write output in JSONL format
|
||||
|
||||
responses = generator.generate(request)
|
||||
|
||||
with open(output_json_file, "w", encoding='utf-8') as f:
|
||||
for query, response in zip(query, responses):
|
||||
for query, response in zip(queries, responses):
|
||||
output_item = {question_key: query, response_key: response}
|
||||
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
|
||||
|
||||
|
|
@ -89,4 +61,6 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
||||
|
||||
args = parser.parse_args()
|
||||
processor(**vars(args))
|
||||
|
||||
with torch.inference_mode():
|
||||
processor(**vars(args))
|
||||
|
|
@ -6,7 +6,9 @@ import argparse
|
|||
import tqdm
|
||||
|
||||
from torch import Tensor
|
||||
from khaosz import Khaosz
|
||||
from khaosz.config.param_config import ModelParameter
|
||||
from khaosz.inference.core import disable_random_init
|
||||
|
||||
|
||||
def compute_perplexity(
|
||||
model: nn.Module,
|
||||
|
|
@ -45,22 +47,23 @@ def process_file(
|
|||
batch_size: int,
|
||||
text_key: str
|
||||
):
|
||||
model = Khaosz(model_dir).to(device="cuda", dtype=torch.bfloat16)
|
||||
tokenizer = model.parameter.tokenizer
|
||||
with disable_random_init():
|
||||
param = ModelParameter.load(model_dir)
|
||||
|
||||
param.to(device='cuda', dtype=torch.bfloat16)
|
||||
model = param.model
|
||||
tokenizer = param.tokenizer
|
||||
|
||||
with open(input_file, "r", encoding='utf-8') as f:
|
||||
input_data = [json.loads(line) for line in f]
|
||||
|
||||
texts = [item[text_key] for item in input_data]
|
||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
output_data = []
|
||||
|
||||
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
|
||||
batch_encoded = encoded_texts[i:i + batch_size]
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Pad sequences to the same length (left padding)
|
||||
max_len = max(len(seq) for seq in batch_encoded)
|
||||
padded_ids = []
|
||||
masks = []
|
||||
|
|
@ -74,10 +77,7 @@ def process_file(
|
|||
|
||||
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
||||
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
||||
|
||||
# Compute perplexity
|
||||
with torch.inference_mode():
|
||||
perplexity = compute_perplexity(model.parameter.model, input_ids, input_mask)
|
||||
perplexity = compute_perplexity(model, input_ids, input_mask)
|
||||
|
||||
for text, ppl in zip(batch_texts, perplexity):
|
||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||
|
|
@ -87,16 +87,14 @@ def process_file(
|
|||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
|
||||
def main():
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
||||
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
|
||||
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
|
||||
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
|
||||
|
||||
args = parser.parse_args()
|
||||
process_file(**vars(args))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
with torch.inference_mode():
|
||||
process_file(**vars(args))
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def parse_args() -> argparse.Namespace:
|
|||
|
||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||
|
||||
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
||||
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
|
||||
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
||||
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
||||
|
||||
|
|
@ -67,18 +67,14 @@ def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler
|
|||
return SchedulerFactory.load(optimizer, **kwargs)
|
||||
|
||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
return state_dict
|
||||
return model.module.state_dict()
|
||||
|
||||
|
||||
def train(
|
||||
train_type: str,
|
||||
param_path: str,
|
||||
data_root_path: str,
|
||||
max_lr: int,
|
||||
max_lr: float,
|
||||
n_epoch: int,
|
||||
batch_size: int,
|
||||
start_epoch: int,
|
||||
|
|
@ -104,8 +100,7 @@ def train(
|
|||
assert train_type in ["seq", "sft", "dpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
||||
parameter = ModelParameter()
|
||||
parameter.load(param_path)
|
||||
parameter = ModelParameter.load(param_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = parameter.config.max_len
|
||||
|
|
|
|||
Loading…
Reference in New Issue