feat: add a factory pattern for the inference module

This commit is contained in:
ViperEkura 2026-03-30 00:55:15 +08:00
parent 980299cd54
commit c01791ff54
17 changed files with 227 additions and 559 deletions

View File

@@ -1,27 +1,35 @@
import os
import torch
from khaosz import Khaosz
from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import LoopGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
def generate_text():
model_dir = os.path.join(PROJECT_ROOT, "params")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
query = input(">> ")
response = model.text_generate(
request = GenerationRequest(
query=query,
temperature=0.8,
top_p=0.95,
top_k=50
top_k=50,
max_len=param.config.max_len,
history=None,
system_prompt=None,
)
generator = LoopGenerator(param)
response = generator.generate(request)
print(response)
if __name__ == "__main__":
generate_text()

View File

@@ -1,22 +1,31 @@
import os
import torch
from khaosz import Khaosz
from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import BatchGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
def batch_generate():
model_dir = os.path.join(PROJECT_ROOT, "params")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
generator = BatchGenerator(param)
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
responses = model.batch_generate(
request = GenerationRequest(
query=inputs,
temperature=0.8,
top_p=0.95,
top_k=50
top_k=50,
max_len=param.config.max_len,
history=None,
system_prompt=None,
)
responses = generator.generate(request)
for q, r in zip(inputs, responses):
print((q, r))

View File

@@ -1,43 +0,0 @@
import os
import torch
from khaosz import Khaosz, SemanticTextSplitter, Retriever
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
if __name__ == "__main__":
model_dir = os.path.join(PROJECT_ROOT, "params")
context_path = os.path.join(PROJECT_ROOT, "README.md")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
spliter = SemanticTextSplitter(model.encode)
retriever = Retriever()
text = open(context_path, "r", encoding="utf-8").read()
res = spliter.split(text, threshold=0.8, window_size=1)
# print(("\n" + "+"*100 + "\n").join(res))
res_embs = model.encode(res)
for sentence, emb in zip(res, res_embs):
retriever.add_vector(sentence, emb)
retrive_top_k = 5
query = "作者设计了一个怎样的模型"
emb_query = model.encode(query)
retrieved = retriever.retrieve(emb_query, retrive_top_k)
retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
retrive_response = model.retrieve_generate(
retrieved=retrieved_content,
query=query,
temperature=0.8,
top_p=0.95,
top_k=50
)
print("retrieve content:")
print(retrieved_content)
print("\n\nretrive generate:")
print(retrive_response)

View File

@@ -1,14 +1,21 @@
import os
import torch
from khaosz import Khaosz
from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
from khaosz.inference.generator import StreamGenerator, GenerationRequest
PROJECT_ROOT = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
def chat():
model_dir = os.path.join(PROJECT_ROOT, "params")
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
with disable_random_init():
model_dir = os.path.join(PROJECT_ROOT, "params")
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
generator = StreamGenerator(param)
history = []
while True:
@@ -16,16 +23,26 @@ def chat():
if query == "!exit":
break
response_size = 0
for response, history in model.stream_generate(
request = GenerationRequest(
query=query,
history=history,
temperature=0.8,
top_p=0.95,
top_k=50
):
top_k=50,
max_len=param.config.max_len,
history=history,
system_prompt=None,
)
response_size = 0
full_response = ""
for response in generator.generate(request):
# response is the cumulative decoded text up to the current token
print(response[response_size:], end="", flush=True)
response_size = len(response)
full_response = response
# After generation, update history
history.append((query, full_response.strip()))
if __name__ == "__main__":

View File

@@ -1,17 +1,11 @@
__version__ = "1.3.2"
__author__ = "ViperEkura"
from khaosz.api import Khaosz
from khaosz.config import (
ModelConfig,
TrainConfig,
)
from khaosz.model.transformer import Transformer
from khaosz.utils.retriever import Retriever
from khaosz.utils.splitter import (
SemanticTextSplitter,
PriorityTextSplitter
)
from khaosz.data import (
DatasetLoader,
BpeTokenizer
@@ -22,8 +16,8 @@ from khaosz.inference.generator import (
StreamGenerator,
BatchGenerator,
EmbeddingEncoder,
GeneratorFactory
)
from khaosz.trainer import (
Trainer,
StrategyFactory,
@@ -31,14 +25,8 @@ from khaosz.trainer import (
)
__all__ = [
"Khaosz",
"Transformer",
"Retriever",
"SemanticTextSplitter",
"PriorityTextSplitter",
"ModelConfig",
"TrainConfig",
@@ -50,6 +38,7 @@ __all__ = [
"StreamGenerator",
"BatchGenerator",
"EmbeddingEncoder",
"GeneratorFactory",
"Trainer",
"StrategyFactory",

View File

@@ -1,138 +0,0 @@
from torch import nn
from torch import Tensor
from contextlib import contextmanager
from typing import List, Tuple, Generator, Union
from khaosz.inference.generator import (
GenerationRequest,
LoopGenerator,
StreamGenerator,
BatchGenerator,
EmbeddingEncoder
)
from khaosz.config.param_config import ModelParameter
@contextmanager
def disable_random_init():
init_functions = [
'xavier_normal_', 'xavier_uniform_',
'kaiming_normal_', 'kaiming_uniform_',
'zeros_', 'ones_', 'constant_',
'normal_', 'uniform_'
]
original_funcs = {}
for name in init_functions:
if hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
class Khaosz:
def __init__(self, model_dir: str):
with disable_random_init():
self.parameter = ModelParameter()
self.parameter.load(model_dir)
def to(self, *args, **kwargs):
self.parameter.to(*args, **kwargs)
return self
def generate(
self,
query: str,
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = LoopGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def batch_generate(
self,
query: List[str],
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> List[str]:
generator = BatchGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def stream_generate(
self,
query: str,
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
stream_generator = StreamGenerator(self.parameter)
return stream_generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
build_prompt=True
))
def retrieve_generate(
self,
retrieved,
query: str,
history: List[Tuple[str, str]] = None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = LoopGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
history=history,
system_prompt=retrieved,
build_prompt=True
))
def text_generate(
self,
query: str,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = LoopGenerator(self.parameter)
return generator.generate(
GenerationRequest(
top_k, top_p, temperature,
self.parameter.config.max_len,
query=query,
build_prompt=False
))
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
encoder = EmbeddingEncoder(self.parameter)
return encoder.encode(sentence)

View File

@@ -72,9 +72,12 @@ class BaseModelIO:
class ModelParameter(BaseModelIO):
"""Container for model parameters with serialization capabilities."""
def save(self, save_dir: Union[str, Path]):
self.save_components(save_dir)
@classmethod
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
instance.save_components(save_dir)
def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
return self.load_components(load_dir)
@classmethod
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
instance = cls()
return instance.load_components(load_dir)
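With save and load as classmethods, callers no longer construct an empty ModelParameter before loading from disk. A minimal sketch of the new call style (the directory names are illustrative):
import os
from khaosz.config.param_config import ModelParameter
param = ModelParameter.load("params")        # classmethod builds and returns the instance
ModelParameter.save(param, "params_backup")  # classmethod takes the instance explicitly
assert os.path.exists("params_backup")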

View File

@@ -1,4 +1,5 @@
from khaosz.inference.core import (
disable_random_init,
GeneratorCore,
EmbeddingEncoderCore,
KVCacheManager,
@@ -10,9 +11,11 @@ from khaosz.inference.generator import (
StreamGenerator,
BatchGenerator,
EmbeddingEncoder,
GeneratorFactory
)
__all__ = [
"disable_random_init",
"GeneratorCore",
"EmbeddingEncoderCore",
"KVCacheManager",
@@ -22,4 +25,5 @@ __all__ = [
"StreamGenerator",
"BatchGenerator",
"EmbeddingEncoder",
"GeneratorFactory"
]

View File

@@ -1,5 +1,8 @@
import torch
import torch.nn as nn
from torch import Tensor
from contextlib import contextmanager
from typing import Any, Callable, List, Tuple, Union, Optional, Self
from khaosz.config import ModelParameter, ModelConfig
@@ -54,6 +57,26 @@ def apply_sampling_strategies(
return logits
@contextmanager
def disable_random_init():
init_functions = [
'xavier_normal_', 'xavier_uniform_',
'kaiming_normal_', 'kaiming_uniform_',
'zeros_', 'ones_', 'constant_',
'normal_', 'uniform_'
]
original_funcs = {}
for name in init_functions:
if hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
class GeneratorCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
@@ -82,10 +105,6 @@ class GeneratorCore:
return next_token_id, cache_increase
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
def generate_loop(
self,
input_ids: Tensor,
@@ -116,6 +135,10 @@ class GeneratorCore:
return ids
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
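The relocated disable_random_init context manager temporarily replaces the listed torch.nn.init functions with no-ops, so constructing a model right before a checkpoint load does not spend time filling weights that load_state_dict will immediately overwrite. A minimal sketch of the effect (the checkpoint path is hypothetical):
import torch
import torch.nn as nn
from khaosz.inference.core import disable_random_init
with disable_random_init():
    layer = nn.Linear(4096, 4096)  # kaiming_uniform_/uniform_ are no-ops here,
                                   # so the weights stay as uninitialized memory
state = torch.load("layer_checkpoint.pt")  # hypothetical checkpoint file
layer.load_state_dict(state)               # real weights are restored here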

View File

@@ -10,32 +10,36 @@ HistoryType = List[Tuple[str, str]]
def build_prompt(
query: str,
init_prompt: Optional[str] = None,
history: Optional[List[Tuple[str, str]]] = None
) -> str:
system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None
) -> str:
"""
Build prompt in ChatML format for query and history
Build prompt in ChatML format for query and history.
Args:
query(str): query string
history(Optional[List[Tuple[str, str]]]): history list of query and response
query (str): query string.
system_prompt (Optional[str]): system prompt string.
history (Optional[HistoryType]): history list of query and response.
Returns:
str: prompt string in ChatML format
str: prompt string in ChatML format.
"""
prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
result = ""
if system_prompt:
result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
# (convert tuple format to ChatML)
if history:
for user_msg, assistant_msg in history:
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
result += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
prompt += "<|im_start|>assistant\n"
result += f"<|im_start|>user\n{query}<|im_end|>\n"
result += "<|im_start|>assistant\n"
return result
return prompt
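For reference, given the template above, a call like build_prompt("你好", system_prompt="You are a helpful assistant.", history=[("1+1=?", "2")]) returns the following ChatML string, ending with an open assistant turn for the model to complete:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
1+1=?<|im_end|>
<|im_start|>assistant
2<|im_end|>
<|im_start|>user
你好<|im_end|>
<|im_start|>assistant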
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
"""
@@ -59,8 +63,21 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
return new_ids_list, max_ids_len
@dataclass
class GenerationRequest:
"""
Request parameters for text generation.
Attributes:
top_k: Top-k sampling parameter.
top_p: Top-p (nucleus) sampling parameter.
temperature: Sampling temperature.
max_len: Maximum generation length.
query: Input query (string or list of strings for batch).
history: Conversation history.
system_prompt: System prompt for the conversation.
"""
top_k: int
top_p: float
temperature: float
@@ -70,8 +87,6 @@ class GenerationRequest:
history: Optional[Union[HistoryType, List[HistoryType]]] = None
system_prompt: Optional[str] = None
build_prompt: bool = True
def __post_init__(self):
if not isinstance(self.top_k, int) or self.top_k < 0:
raise ValueError("top_k must be a non-negative integer")
@@ -89,19 +104,21 @@ class LoopGenerator(GeneratorCore):
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device)
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
ids = self.tokenizer.encode(input_args)
prompt = build_prompt(request.query, system_prompt=request.system_prompt, history=request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
kv_caches = cache_manager.get_kvcache()
ids = self.generate_loop(
input_ids, ids, request.temperature, request.top_k, request.top_p,
input_ids,
ids,
request.temperature,
request.top_k,
request.top_p,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
response = self.tokenizer.decode(ids[start_cache_pos:])
@@ -112,16 +129,12 @@ class StreamGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
if request.history is None:
request.history = []
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(self.config, 1, device=device)
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
ids = self.tokenizer.encode(input_args)
prompt = build_prompt(request.query, system_prompt=request.system_prompt, history=request.history)
ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
@@ -141,10 +154,10 @@ class StreamGenerator(GeneratorCore):
cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:])
yield response, request.history + [(request.query, response)]
yield response
if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n", request.history + [(request.query, response)]
yield response + "\n"
break
@@ -218,3 +231,35 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
return super().encode(sentence)
class GeneratorFactory:
"""Factory class for creating appropriate generator instances based on request features."""
@staticmethod
def create_generator(parameter: ModelParameter, request: GenerationRequest):
"""
Create a generator based on the characteristics of GenerationRequest.
Args:
parameter: Model parameters
request: Generation request
Returns:
Subclass instance of GeneratorCore
"""
# Streaming generation detection: check stream field
if request.stream:
return StreamGenerator(parameter)
# Batch generation detection: query is a list
if isinstance(request.query, list):
return BatchGenerator(parameter)
# Default: fall back to LoopGenerator
return LoopGenerator(parameter)
@staticmethod
def create_encoder(parameter: ModelParameter):
"""Create an EmbeddingEncoder instance"""
return EmbeddingEncoder(parameter)
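A minimal usage sketch of the factory (directory name illustrative; note that the stream check above assumes a stream attribute on GenerationRequest that is not visible in this diff, so the sketch exercises the batch path instead):
from khaosz.config.param_config import ModelParameter
from khaosz.inference.generator import GeneratorFactory, GenerationRequest
param = ModelParameter.load("params")
request = GenerationRequest(
    top_k=50,
    top_p=0.95,
    temperature=0.8,
    max_len=param.config.max_len,
    query=["你好", "请问什么是显卡"],  # a list of queries selects BatchGenerator
)
generator = GeneratorFactory.create_generator(param, request)
responses = generator.generate(request)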

View File

@@ -1 +0,0 @@
# init file

View File

@@ -1,88 +0,0 @@
import torch
import sqlite3
import numpy as np
from torch import Tensor
from typing import Dict, List, Tuple
class Retriever:
def __init__(self, db_path=None):
self.data: Dict[str, Tensor] = {}
self.embedding_cache: Tensor = None
self.is_caculated: bool = False
if db_path is not None:
self.load(db_path)
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
if not self.data:
return []
query = query.flatten().unsqueeze(1) # [dim, 1]
norm_embeddings = self._embeddings.to(
device=query.device,
dtype=query.dtype
) # [n_vectors, dim]
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
top_k = min(top_k, len(self.data))
indices = sim_scores.topk(top_k).indices
keys = list(self.data.keys())
return [(keys[i], sim_scores[i].item()) for i in indices]
def add_vector(self, key: str, vector_data: Tensor):
self.is_caculated = False
self.data[key] = vector_data.flatten().float().cpu()
def delete_vector(self, key: str):
self.is_caculated = False
self.data.pop(key, None)
def save(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('DELETE FROM vectors')
for item, vec in self.data.items():
vec_bytes = vec.numpy().tobytes()
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
(item, vec_bytes))
conn.commit()
conn.close()
def load(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('SELECT key, vector FROM vectors')
rows = cursor.fetchall()
self.data = {}
for row in rows:
key, vec_bytes = row
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
vec = torch.from_numpy(vec_numpy)
self.data[key] = vec
conn.close()
def _init_db(self,cursor: sqlite3.Cursor):
# Create table if not exists (in case loading from a new database)
cursor.execute('''
CREATE TABLE IF NOT EXISTS vectors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT UNIQUE NOT NULL,
vector BLOB NOT NULL
)''')
@property
def _embeddings(self) -> Tensor:
if not self.is_caculated:
embeddings = torch.stack(list(self.data.values()))
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
self.embedding_cache = norm_embeddings
return self.embedding_cache

View File

@@ -1,127 +0,0 @@
import re
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from torch import Tensor
from typing import List, Callable, Optional
class BaseTextSplitter(ABC):
def __init__(
self,
max_len: int = 512,
chunk_overlap: int = 0,
):
if max_len <= 0:
raise ValueError("max_len must be > 0")
if chunk_overlap < 0:
raise ValueError("chunk_overlap must be >= 0")
self.max_len = max_len
self.chunk_overlap = chunk_overlap
@abstractmethod
def split(self, text: str, **kwargs) -> List[str]:
raise NotImplementedError
def preprocess(self, text: str) -> str:
return text.strip()
def postprocess(self, chunks: List[str]) -> List[str]:
return [chunk.strip() for chunk in chunks if chunk.strip()]
class PriorityTextSplitter(BaseTextSplitter):
def __init__(
self,
separators: List[str],
max_len: int = 512,
chunk_overlap: int = 0,
):
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
if not separators:
raise ValueError("separators must be a non-empty list")
self.separators = separators
def split(self, text: str) -> List[str]:
text = self.preprocess(text)
for sep in self.separators:
parts = text.split(sep)
valid_parts = [p.strip() for p in parts if p.strip()]
if len(valid_parts) > 1:
return self.postprocess(valid_parts)
return [text]
class SemanticTextSplitter(BaseTextSplitter):
DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'
def __init__(
self,
embedding_func: Callable[[List[str]], List[Tensor]],
pattern: Optional[str] = None,
max_len: int = 512,
chunk_overlap: int = 0,
):
super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
if not callable(embedding_func):
raise TypeError("embedding_func must be callable")
self.embedding_func = embedding_func
self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN
def split(
self,
text: str,
threshold: float = 0.5,
window_size: int = 1,
) -> List[str]:
text = self.preprocess(text)
sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]
if len(sentences) <= 1:
return self.postprocess(sentences)
try:
sentence_embs = self.embedding_func(sentences)
except Exception as e:
raise RuntimeError(f"Embedding generation failed: {e}")
if len(sentence_embs) != len(sentences):
raise ValueError("Embedding function must return one vector per sentence")
chunks = []
emb_tensor = torch.stack(sentence_embs) # shape: [N, D]
current_chunk: List[str] = [sentences[0]]
for i in range(1, len(sentences)):
start_prev = max(0, i - window_size)
end_prev = i
start_next = i
end_next = min(len(sentences), i + window_size)
prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0)
next_window_emb = emb_tensor[start_next:end_next].mean(dim=0)
similarity = F.cosine_similarity(
prev_window_emb.unsqueeze(0),
next_window_emb.unsqueeze(0),
dim=1
).item()
dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)
if similarity < dynamic_threshold:
chunks.append(" ".join(current_chunk))
overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
current_chunk = current_chunk[overlap_start:]
current_chunk.append(sentences[i])
else:
current_chunk.append(sentences[i])
if current_chunk:
chunks.append(" ".join(current_chunk))
return self.postprocess(chunks)

View File

@@ -54,7 +54,7 @@ def test_env(request: pytest.FixtureRequest):
def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(test_env["model"], test_env["tokenizer"], test_env["transformer_config"])
model_param.save(save_dir)
ModelParameter.save(model_param, save_dir)
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))

View File

@@ -1,45 +1,10 @@
import torch
import json
import torch
import argparse
from khaosz import Khaosz
from typing import List
from tqdm import tqdm
def batch_generate(
model: Khaosz,
query: List[str],
temperature: float,
top_k: int,
top_p: float,
batch_size: int,
) -> List:
assert batch_size > 0
sorted_query = sorted(query, key=lambda x: len(x), reverse=True)
original_indices = {query: idx for idx, query in enumerate(query)}
responses = [None] * len(query)
total_batches = (len(sorted_query) + batch_size - 1) // batch_size
for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
batch_query = sorted_query[i: min(i + batch_size, len(query))]
if not isinstance(batch_query, list):
batch_query = [batch_query]
batch_responses = model.batch_generate(
query=batch_query,
temperature=temperature,
top_k=top_k,
top_p=top_p
)
for query, response in zip(batch_query, batch_responses):
original_idx = original_indices[query]
responses[original_idx] = response
return responses
from khaosz.config.param_config import ModelParameter
from khaosz.inference.generator import BatchGenerator, GenerationRequest
from khaosz.inference.core import disable_random_init
def processor(
@@ -53,24 +18,31 @@ def processor(
question_key: str,
response_key: str,
):
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
with disable_random_init():
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
generator = BatchGenerator(param)
with open(input_json_file, "r", encoding='utf-8') as f:
input_data = [json.loads(line) for line in f]
query = [item[question_key] for item in input_data]
responses = batch_generate(
model=model,
query=query,
queries = [item[question_key] for item in input_data]
request = GenerationRequest(
query=queries,
temperature=temperature,
top_k=top_k,
top_p=top_p,
batch_size=batch_size
top_k=top_k,
max_len=param.config.max_len,
history=None,
system_prompt=None,
)
# Write output in JSONL format
responses = generator.generate(request)
with open(output_json_file, "w", encoding='utf-8') as f:
for query, response in zip(query, responses):
for query, response in zip(queries, responses):
output_item = {question_key: query, response_key: response}
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
@@ -89,4 +61,6 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
args = parser.parse_args()
processor(**vars(args))
with torch.inference_mode():
processor(**vars(args))

View File

@@ -6,7 +6,9 @@ import argparse
import tqdm
from torch import Tensor
from khaosz import Khaosz
from khaosz.config.param_config import ModelParameter
from khaosz.inference.core import disable_random_init
def compute_perplexity(
model: nn.Module,
@@ -45,22 +47,23 @@ def process_file(
batch_size: int,
text_key: str
):
model = Khaosz(model_dir).to(device="cuda", dtype=torch.bfloat16)
tokenizer = model.parameter.tokenizer
with disable_random_init():
param = ModelParameter.load(model_dir)
param.to(device='cuda', dtype=torch.bfloat16)
model = param.model
tokenizer = param.tokenizer
with open(input_file, "r", encoding='utf-8') as f:
input_data = [json.loads(line) for line in f]
texts = [item[text_key] for item in input_data]
encoded_texts = [tokenizer.encode(text) for text in texts]
output_data = []
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
batch_encoded = encoded_texts[i:i + batch_size]
batch_texts = texts[i:i + batch_size]
# Pad sequences to the same length (left padding)
max_len = max(len(seq) for seq in batch_encoded)
padded_ids = []
masks = []
@@ -74,10 +77,7 @@ def process_file(
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
# Compute perplexity
with torch.inference_mode():
perplexity = compute_perplexity(model.parameter.model, input_ids, input_mask)
perplexity = compute_perplexity(model, input_ids, input_mask)
for text, ppl in zip(batch_texts, perplexity):
output_data.append({text_key: text, "ppl": float(ppl.item())})
@@ -87,16 +87,14 @@ def process_file(
f.write(json.dumps(item, ensure_ascii=False) + '\n')
def main():
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
args = parser.parse_args()
process_file(**vars(args))
if __name__ == "__main__":
main()
with torch.inference_mode():
process_file(**vars(args))
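compute_perplexity itself is unchanged by this commit. For context, per-sequence perplexity is conventionally the exponential of the mean token negative log-likelihood over unmasked positions; a hedged sketch of that computation (an assumption about the helper's internals, whose body this diff does not show):
import torch
import torch.nn.functional as F
from torch import Tensor
def perplexity_from_logits(logits: Tensor, targets: Tensor, mask: Tensor) -> Tensor:
    # logits: [batch, seq, vocab]; targets: [batch, seq]; mask: [batch, seq] (bool)
    nll = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")  # [batch, seq]
    nll = (nll * mask).sum(dim=1) / mask.sum(dim=1)  # mean NLL over valid tokens
    return torch.exp(nll)  # perplexity per sequence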

View File

@@ -16,7 +16,7 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
@@ -67,18 +67,14 @@ def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler
return SchedulerFactory.load(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
return state_dict
return model.module.state_dict()
def train(
train_type: str,
param_path: str,
data_root_path: str,
max_lr: int,
max_lr: float,
n_epoch: int,
batch_size: int,
start_epoch: int,
@@ -104,8 +100,7 @@ def train(
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
parameter = ModelParameter()
parameter.load(param_path)
parameter = ModelParameter.load(param_path)
if window_size is None:
window_size = parameter.config.max_len