feat: 增加推理部分工厂模式
This commit is contained in:
parent
980299cd54
commit
c01791ff54
|
|
@ -1,27 +1,35 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from khaosz import Khaosz
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
from khaosz.inference.core import disable_random_init
|
||||||
|
from khaosz.inference.generator import LoopGenerator, GenerationRequest
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
def generate_text():
|
def generate_text():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
with disable_random_init():
|
||||||
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
|
param.to(device='cuda', dtype=torch.bfloat16)
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
|
|
||||||
response = model.text_generate(
|
request = GenerationRequest(
|
||||||
query=query,
|
query=query,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50
|
top_k=50,
|
||||||
|
max_len=param.config.max_len,
|
||||||
|
history=None,
|
||||||
|
system_prompt=None,
|
||||||
)
|
)
|
||||||
|
generator = LoopGenerator(param)
|
||||||
|
response = generator.generate(request)
|
||||||
|
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
generate_text()
|
generate_text()
|
||||||
|
|
@ -1,22 +1,31 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from khaosz import Khaosz
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
from khaosz.inference.core import disable_random_init
|
||||||
|
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
def batch_generate():
|
def batch_generate():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
with disable_random_init():
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
|
param.to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
generator = BatchGenerator(param)
|
||||||
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
|
inputs = ["你好", "请问什么是人工智能", "今天天气如何", "我感到焦虑, 请问我应该怎么办", "请问什么是显卡"]
|
||||||
|
|
||||||
responses = model.batch_generate(
|
request = GenerationRequest(
|
||||||
query=inputs,
|
query=inputs,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50
|
top_k=50,
|
||||||
|
max_len=param.config.max_len,
|
||||||
|
history=None,
|
||||||
|
system_prompt=None,
|
||||||
)
|
)
|
||||||
|
responses = generator.generate(request)
|
||||||
|
|
||||||
for q, r in zip(inputs, responses):
|
for q, r in zip(inputs, responses):
|
||||||
print((q, r))
|
print((q, r))
|
||||||
|
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
from khaosz import Khaosz, SemanticTextSplitter, Retriever
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
|
||||||
context_path = os.path.join(PROJECT_ROOT, "README.md")
|
|
||||||
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
|
||||||
spliter = SemanticTextSplitter(model.encode)
|
|
||||||
retriever = Retriever()
|
|
||||||
text = open(context_path, "r", encoding="utf-8").read()
|
|
||||||
|
|
||||||
res = spliter.split(text, threshold=0.8, window_size=1)
|
|
||||||
# print(("\n" + "+"*100 + "\n").join(res))
|
|
||||||
|
|
||||||
res_embs = model.encode(res)
|
|
||||||
for sentence, emb in zip(res, res_embs):
|
|
||||||
retriever.add_vector(sentence, emb)
|
|
||||||
|
|
||||||
retrive_top_k = 5
|
|
||||||
query = "作者设计了一个怎样的模型"
|
|
||||||
emb_query = model.encode(query)
|
|
||||||
retrieved = retriever.retrieve(emb_query, retrive_top_k)
|
|
||||||
retrieved_content = "\n".join([f"{idx + 1}. " + text for idx, (text, _) in enumerate(retrieved)])
|
|
||||||
|
|
||||||
retrive_response = model.retrieve_generate(
|
|
||||||
retrieved=retrieved_content,
|
|
||||||
query=query,
|
|
||||||
temperature=0.8,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=50
|
|
||||||
)
|
|
||||||
|
|
||||||
print("retrieve content:")
|
|
||||||
print(retrieved_content)
|
|
||||||
|
|
||||||
print("\n\nretrive generate:")
|
|
||||||
print(retrive_response)
|
|
||||||
|
|
@ -1,14 +1,21 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from khaosz import Khaosz
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
from khaosz.inference.core import disable_random_init
|
||||||
|
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(
|
PROJECT_ROOT = os.path.dirname(
|
||||||
os.path.dirname(os.path.abspath(__file__)))
|
os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
def chat():
|
def chat():
|
||||||
model_dir = os.path.join(PROJECT_ROOT, "params")
|
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
with disable_random_init():
|
||||||
|
model_dir = os.path.join(PROJECT_ROOT, "params")
|
||||||
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
|
param.to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
generator = StreamGenerator(param)
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -16,17 +23,27 @@ def chat():
|
||||||
if query == "!exit":
|
if query == "!exit":
|
||||||
break
|
break
|
||||||
|
|
||||||
response_size = 0
|
request = GenerationRequest(
|
||||||
for response, history in model.stream_generate(
|
query=query,
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50
|
top_k=50,
|
||||||
):
|
max_len=param.config.max_len,
|
||||||
|
history=history,
|
||||||
|
system_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_size = 0
|
||||||
|
full_response = ""
|
||||||
|
for response in generator.generate(request):
|
||||||
|
# response is the cumulative response up to current token
|
||||||
print(response[response_size:], end="", flush=True)
|
print(response[response_size:], end="", flush=True)
|
||||||
response_size = len(response)
|
response_size = len(response)
|
||||||
|
full_response = response
|
||||||
|
|
||||||
|
# After generation, update history
|
||||||
|
history.append((query, full_response.strip()))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
chat()
|
chat()
|
||||||
|
|
@ -1,17 +1,11 @@
|
||||||
__version__ = "1.3.2"
|
__version__ = "1.3.2"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from khaosz.api import Khaosz
|
|
||||||
from khaosz.config import (
|
from khaosz.config import (
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from khaosz.model.transformer import Transformer
|
||||||
from khaosz.utils.retriever import Retriever
|
|
||||||
from khaosz.utils.splitter import (
|
|
||||||
SemanticTextSplitter,
|
|
||||||
PriorityTextSplitter
|
|
||||||
)
|
|
||||||
from khaosz.data import (
|
from khaosz.data import (
|
||||||
DatasetLoader,
|
DatasetLoader,
|
||||||
BpeTokenizer
|
BpeTokenizer
|
||||||
|
|
@ -22,8 +16,8 @@ from khaosz.inference.generator import (
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
|
GeneratorFactory
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.trainer import (
|
from khaosz.trainer import (
|
||||||
Trainer,
|
Trainer,
|
||||||
StrategyFactory,
|
StrategyFactory,
|
||||||
|
|
@ -31,14 +25,8 @@ from khaosz.trainer import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Khaosz",
|
|
||||||
|
|
||||||
"Transformer",
|
"Transformer",
|
||||||
|
|
||||||
"Retriever",
|
|
||||||
"SemanticTextSplitter",
|
|
||||||
"PriorityTextSplitter",
|
|
||||||
|
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
||||||
|
|
@ -50,6 +38,7 @@ __all__ = [
|
||||||
"StreamGenerator",
|
"StreamGenerator",
|
||||||
"BatchGenerator",
|
"BatchGenerator",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
|
"GeneratorFactory",
|
||||||
|
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
|
|
|
||||||
138
khaosz/api.py
138
khaosz/api.py
|
|
@ -1,138 +0,0 @@
|
||||||
from torch import nn
|
|
||||||
from torch import Tensor
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import List, Tuple, Generator, Union
|
|
||||||
|
|
||||||
from khaosz.inference.generator import (
|
|
||||||
GenerationRequest,
|
|
||||||
LoopGenerator,
|
|
||||||
StreamGenerator,
|
|
||||||
BatchGenerator,
|
|
||||||
EmbeddingEncoder
|
|
||||||
)
|
|
||||||
from khaosz.config.param_config import ModelParameter
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def disable_random_init():
|
|
||||||
init_functions = [
|
|
||||||
'xavier_normal_', 'xavier_uniform_',
|
|
||||||
'kaiming_normal_', 'kaiming_uniform_',
|
|
||||||
'zeros_', 'ones_', 'constant_',
|
|
||||||
'normal_', 'uniform_'
|
|
||||||
]
|
|
||||||
original_funcs = {}
|
|
||||||
for name in init_functions:
|
|
||||||
if hasattr(nn.init, name):
|
|
||||||
original_funcs[name] = getattr(nn.init, name)
|
|
||||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
for name, orig_func in original_funcs.items():
|
|
||||||
setattr(nn.init, name, orig_func)
|
|
||||||
|
|
||||||
|
|
||||||
class Khaosz:
|
|
||||||
def __init__(self, model_dir: str):
|
|
||||||
with disable_random_init():
|
|
||||||
self.parameter = ModelParameter()
|
|
||||||
self.parameter.load(model_dir)
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
self.parameter.to(*args, **kwargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
history: List[Tuple[str, str]]=None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> str:
|
|
||||||
generator = LoopGenerator(self.parameter)
|
|
||||||
return generator.generate(
|
|
||||||
GenerationRequest(
|
|
||||||
top_k, top_p, temperature,
|
|
||||||
self.parameter.config.max_len,
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
build_prompt=True
|
|
||||||
))
|
|
||||||
|
|
||||||
def batch_generate(
|
|
||||||
self,
|
|
||||||
query: List[str],
|
|
||||||
history: List[Tuple[str, str]]=None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> List[str]:
|
|
||||||
generator = BatchGenerator(self.parameter)
|
|
||||||
return generator.generate(
|
|
||||||
GenerationRequest(
|
|
||||||
top_k, top_p, temperature,
|
|
||||||
self.parameter.config.max_len,
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
build_prompt=True
|
|
||||||
))
|
|
||||||
|
|
||||||
def stream_generate(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
history: List[Tuple[str, str]]=None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
|
|
||||||
stream_generator = StreamGenerator(self.parameter)
|
|
||||||
return stream_generator.generate(
|
|
||||||
GenerationRequest(
|
|
||||||
top_k, top_p, temperature,
|
|
||||||
self.parameter.config.max_len,
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
build_prompt=True
|
|
||||||
))
|
|
||||||
|
|
||||||
def retrieve_generate(
|
|
||||||
self,
|
|
||||||
retrieved,
|
|
||||||
query: str,
|
|
||||||
history: List[Tuple[str, str]] = None,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> str:
|
|
||||||
generator = LoopGenerator(self.parameter)
|
|
||||||
return generator.generate(
|
|
||||||
GenerationRequest(
|
|
||||||
top_k, top_p, temperature,
|
|
||||||
self.parameter.config.max_len,
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
system_prompt=retrieved,
|
|
||||||
build_prompt=True
|
|
||||||
))
|
|
||||||
|
|
||||||
def text_generate(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
temperature: float=0.8,
|
|
||||||
top_k: int=50,
|
|
||||||
top_p: float=0.95,
|
|
||||||
) -> str:
|
|
||||||
generator = LoopGenerator(self.parameter)
|
|
||||||
return generator.generate(
|
|
||||||
GenerationRequest(
|
|
||||||
top_k, top_p, temperature,
|
|
||||||
self.parameter.config.max_len,
|
|
||||||
query=query,
|
|
||||||
build_prompt=False
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
|
||||||
encoder = EmbeddingEncoder(self.parameter)
|
|
||||||
return encoder.encode(sentence)
|
|
||||||
|
|
@ -72,9 +72,12 @@ class BaseModelIO:
|
||||||
class ModelParameter(BaseModelIO):
|
class ModelParameter(BaseModelIO):
|
||||||
"""Container for model parameters with serialization capabilities."""
|
"""Container for model parameters with serialization capabilities."""
|
||||||
|
|
||||||
def save(self, save_dir: Union[str, Path]):
|
@classmethod
|
||||||
self.save_components(save_dir)
|
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
|
||||||
|
instance.save_components(save_dir)
|
||||||
|
|
||||||
def load(self, load_dir: Union[str, Path]) -> "ModelParameter":
|
@classmethod
|
||||||
return self.load_components(load_dir)
|
def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
|
||||||
|
instance = cls()
|
||||||
|
return instance.load_components(load_dir)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from khaosz.inference.core import (
|
from khaosz.inference.core import (
|
||||||
|
disable_random_init,
|
||||||
GeneratorCore,
|
GeneratorCore,
|
||||||
EmbeddingEncoderCore,
|
EmbeddingEncoderCore,
|
||||||
KVCacheManager,
|
KVCacheManager,
|
||||||
|
|
@ -10,9 +11,11 @@ from khaosz.inference.generator import (
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
|
GeneratorFactory
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"disable_random_init",
|
||||||
"GeneratorCore",
|
"GeneratorCore",
|
||||||
"EmbeddingEncoderCore",
|
"EmbeddingEncoderCore",
|
||||||
"KVCacheManager",
|
"KVCacheManager",
|
||||||
|
|
@ -22,4 +25,5 @@ __all__ = [
|
||||||
"StreamGenerator",
|
"StreamGenerator",
|
||||||
"BatchGenerator",
|
"BatchGenerator",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
|
"GeneratorFactory"
|
||||||
]
|
]
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||||
from khaosz.config import ModelParameter, ModelConfig
|
from khaosz.config import ModelParameter, ModelConfig
|
||||||
|
|
||||||
|
|
@ -54,6 +57,26 @@ def apply_sampling_strategies(
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_random_init():
|
||||||
|
init_functions = [
|
||||||
|
'xavier_normal_', 'xavier_uniform_',
|
||||||
|
'kaiming_normal_', 'kaiming_uniform_',
|
||||||
|
'zeros_', 'ones_', 'constant_',
|
||||||
|
'normal_', 'uniform_'
|
||||||
|
]
|
||||||
|
original_funcs = {}
|
||||||
|
for name in init_functions:
|
||||||
|
if hasattr(nn.init, name):
|
||||||
|
original_funcs[name] = getattr(nn.init, name)
|
||||||
|
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for name, orig_func in original_funcs.items():
|
||||||
|
setattr(nn.init, name, orig_func)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorCore:
|
class GeneratorCore:
|
||||||
def __init__(self, parameter: ModelParameter):
|
def __init__(self, parameter: ModelParameter):
|
||||||
self.model = parameter.model
|
self.model = parameter.model
|
||||||
|
|
@ -82,10 +105,6 @@ class GeneratorCore:
|
||||||
|
|
||||||
return next_token_id, cache_increase
|
return next_token_id, cache_increase
|
||||||
|
|
||||||
def to(self, *args, **kargs) -> Self:
|
|
||||||
self.model.to(*args, **kargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def generate_loop(
|
def generate_loop(
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
|
|
@ -115,6 +134,10 @@ class GeneratorCore:
|
||||||
break
|
break
|
||||||
|
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
def to(self, *args, **kargs) -> Self:
|
||||||
|
self.model.to(*args, **kargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingEncoderCore:
|
class EmbeddingEncoderCore:
|
||||||
|
|
@ -203,7 +226,7 @@ class KVCacheManager:
|
||||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||||
self._seq_mask: Tensor = None
|
self._seq_mask: Tensor = None
|
||||||
self._initialize()
|
self._initialize()
|
||||||
|
|
||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
k_cache = torch.zeros(
|
k_cache = torch.zeros(
|
||||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||||
|
|
|
||||||
|
|
@ -9,33 +9,37 @@ from khaosz.config.param_config import ModelParameter
|
||||||
HistoryType = List[Tuple[str, str]]
|
HistoryType = List[Tuple[str, str]]
|
||||||
|
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
query: str,
|
query: str,
|
||||||
init_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
history: Optional[List[Tuple[str, str]]] = None
|
history: Optional[HistoryType] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
|
||||||
Build prompt in ChatML format for query and history
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query(str): query string
|
|
||||||
history(Optional[List[Tuple[str, str]]]): history list of query and response
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: prompt string in ChatML format
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
|
Build prompt in ChatML format for query and history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): query string.
|
||||||
|
system_prompt (Optional[str]): system prompt string.
|
||||||
|
history (Optional[HistoryType]): history list of query and response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: prompt string in ChatML format.
|
||||||
|
"""
|
||||||
|
result = ""
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
|
||||||
|
|
||||||
# (convert tuple format to ChatML)
|
# (convert tuple format to ChatML)
|
||||||
if history:
|
if history:
|
||||||
for user_msg, assistant_msg in history:
|
for user_msg, assistant_msg in history:
|
||||||
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
|
result += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
|
||||||
prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
|
result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
|
||||||
|
|
||||||
prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
|
result += f"<|im_start|>user\n{query}<|im_end|>\n"
|
||||||
prompt += "<|im_start|>assistant\n"
|
result += "<|im_start|>assistant\n"
|
||||||
|
|
||||||
return prompt
|
return result
|
||||||
|
|
||||||
|
|
||||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -59,8 +63,21 @@ def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]
|
||||||
|
|
||||||
return new_ids_list, max_ids_len
|
return new_ids_list, max_ids_len
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
|
"""
|
||||||
|
Request parameters for text generation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
top_k: Top-k sampling parameter.
|
||||||
|
top_p: Top-p (nucleus) sampling parameter.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
max_len: Maximum generation length.
|
||||||
|
query: Input query (string or list of strings for batch).
|
||||||
|
history: Conversation history.
|
||||||
|
system_prompt: System prompt for the conversation.
|
||||||
|
"""
|
||||||
top_k: int
|
top_k: int
|
||||||
top_p: float
|
top_p: float
|
||||||
temperature: float
|
temperature: float
|
||||||
|
|
@ -70,8 +87,6 @@ class GenerationRequest:
|
||||||
history: Optional[Union[HistoryType, List[HistoryType]]] = None
|
history: Optional[Union[HistoryType, List[HistoryType]]] = None
|
||||||
system_prompt: Optional[str] = None
|
system_prompt: Optional[str] = None
|
||||||
|
|
||||||
build_prompt: bool = True
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not isinstance(self.top_k, int) or self.top_k < 0:
|
if not isinstance(self.top_k, int) or self.top_k < 0:
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
|
@ -89,19 +104,21 @@ class LoopGenerator(GeneratorCore):
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||||
|
|
||||||
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
|
prompt = build_prompt(request.query, request.history)
|
||||||
ids = self.tokenizer.encode(input_args)
|
ids = self.tokenizer.encode(prompt)
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
|
|
||||||
start_cache_pos = len(ids)
|
start_cache_pos = len(ids)
|
||||||
cur_cache_pos = 0
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
|
|
||||||
ids = self.generate_loop(
|
ids = self.generate_loop(
|
||||||
input_ids, ids, request.temperature, request.top_k, request.top_p,
|
input_ids,
|
||||||
|
ids,
|
||||||
|
request.temperature,
|
||||||
|
request.top_k,
|
||||||
|
request.top_p,
|
||||||
kv_caches=kv_caches,
|
kv_caches=kv_caches,
|
||||||
start_pos=cur_cache_pos
|
|
||||||
)
|
)
|
||||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||||
|
|
||||||
|
|
@ -112,16 +129,12 @@ class StreamGenerator(GeneratorCore):
|
||||||
def __init__(self, parameter: ModelParameter):
|
def __init__(self, parameter: ModelParameter):
|
||||||
super().__init__(parameter)
|
super().__init__(parameter)
|
||||||
|
|
||||||
def generate(self, request: GenerationRequest) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
|
def generate(self, request: GenerationRequest) -> Generator[str, None, None]:
|
||||||
|
|
||||||
if request.history is None:
|
|
||||||
request.history = []
|
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
device = next(self.model.parameters()).device
|
||||||
cache_manager = KVCacheManager(self.config, 1, device=device)
|
cache_manager = KVCacheManager(self.config, 1, device=device)
|
||||||
|
|
||||||
input_args = build_prompt(request.query, request.history) if request.build_prompt else request.query
|
prompt = build_prompt(request.query, request.history)
|
||||||
ids = self.tokenizer.encode(input_args)
|
ids = self.tokenizer.encode(prompt)
|
||||||
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
|
||||||
|
|
||||||
start_cache_pos = len(ids)
|
start_cache_pos = len(ids)
|
||||||
|
|
@ -141,10 +154,10 @@ class StreamGenerator(GeneratorCore):
|
||||||
cur_cache_pos += cache_increase
|
cur_cache_pos += cache_increase
|
||||||
|
|
||||||
response = self.tokenizer.decode(ids[start_cache_pos:])
|
response = self.tokenizer.decode(ids[start_cache_pos:])
|
||||||
yield response, request.history + [(request.query, response)]
|
yield response
|
||||||
|
|
||||||
if next_token_id.item() in self.tokenizer.stop_ids:
|
if next_token_id.item() in self.tokenizer.stop_ids:
|
||||||
yield response + "\n", request.history + [(request.query, response)]
|
yield response + "\n"
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -217,4 +230,36 @@ class EmbeddingEncoder(EmbeddingEncoderCore):
|
||||||
|
|
||||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||||
return super().encode(sentence)
|
return super().encode(sentence)
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorFactory:
|
||||||
|
"""Factory class for creating appropriate generator instances based on request features."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_generator(parameter: ModelParameter, request: GenerationRequest):
|
||||||
|
"""
|
||||||
|
Create a generator based on the characteristics of GenerationRequest.
|
||||||
|
Args:
|
||||||
|
parameter: Model parameters
|
||||||
|
request: Generation request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Subclass instance of GeneratorCore
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Streaming generation detection: check stream field
|
||||||
|
if request.stream:
|
||||||
|
return StreamGenerator(parameter)
|
||||||
|
|
||||||
|
# Batch generation detection: query is a list
|
||||||
|
if isinstance(request.query, list):
|
||||||
|
return BatchGenerator(parameter)
|
||||||
|
|
||||||
|
# Default return LoopGenerator
|
||||||
|
return LoopGenerator(parameter)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_encoder(parameter: ModelParameter):
|
||||||
|
"""Create an EmbeddingEncoder instance"""
|
||||||
|
return EmbeddingEncoder(parameter)
|
||||||
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
# init file
|
|
||||||
|
|
@ -1,88 +0,0 @@
|
||||||
import torch
|
|
||||||
import sqlite3
|
|
||||||
import numpy as np
|
|
||||||
from torch import Tensor
|
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
class Retriever:
|
|
||||||
def __init__(self, db_path=None):
|
|
||||||
self.data: Dict[str, Tensor] = {}
|
|
||||||
self.embedding_cache: Tensor = None
|
|
||||||
self.is_caculated: bool = False
|
|
||||||
|
|
||||||
if db_path is not None:
|
|
||||||
self.load(db_path)
|
|
||||||
|
|
||||||
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
|
|
||||||
if not self.data:
|
|
||||||
return []
|
|
||||||
|
|
||||||
query = query.flatten().unsqueeze(1) # [dim, 1]
|
|
||||||
norm_embeddings = self._embeddings.to(
|
|
||||||
device=query.device,
|
|
||||||
dtype=query.dtype
|
|
||||||
) # [n_vectors, dim]
|
|
||||||
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
|
|
||||||
|
|
||||||
top_k = min(top_k, len(self.data))
|
|
||||||
indices = sim_scores.topk(top_k).indices
|
|
||||||
keys = list(self.data.keys())
|
|
||||||
|
|
||||||
return [(keys[i], sim_scores[i].item()) for i in indices]
|
|
||||||
|
|
||||||
def add_vector(self, key: str, vector_data: Tensor):
|
|
||||||
self.is_caculated = False
|
|
||||||
self.data[key] = vector_data.flatten().float().cpu()
|
|
||||||
|
|
||||||
def delete_vector(self, key: str):
|
|
||||||
self.is_caculated = False
|
|
||||||
self.data.pop(key, None)
|
|
||||||
|
|
||||||
def save(self, db_path):
|
|
||||||
conn = sqlite3.connect(db_path)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
self._init_db(cursor)
|
|
||||||
cursor.execute('DELETE FROM vectors')
|
|
||||||
|
|
||||||
for item, vec in self.data.items():
|
|
||||||
vec_bytes = vec.numpy().tobytes()
|
|
||||||
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
|
|
||||||
(item, vec_bytes))
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def load(self, db_path):
|
|
||||||
conn = sqlite3.connect(db_path)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
self._init_db(cursor)
|
|
||||||
cursor.execute('SELECT key, vector FROM vectors')
|
|
||||||
rows = cursor.fetchall()
|
|
||||||
self.data = {}
|
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
key, vec_bytes = row
|
|
||||||
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
|
|
||||||
vec = torch.from_numpy(vec_numpy)
|
|
||||||
self.data[key] = vec
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def _init_db(self,cursor: sqlite3.Cursor):
|
|
||||||
# Create table if not exists (in case loading from a new database)
|
|
||||||
cursor.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS vectors (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
key TEXT UNIQUE NOT NULL,
|
|
||||||
vector BLOB NOT NULL
|
|
||||||
)''')
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _embeddings(self) -> Tensor:
|
|
||||||
if not self.is_caculated:
|
|
||||||
embeddings = torch.stack(list(self.data.values()))
|
|
||||||
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
|
|
||||||
self.embedding_cache = norm_embeddings
|
|
||||||
|
|
||||||
return self.embedding_cache
|
|
||||||
|
|
@ -1,127 +0,0 @@
|
||||||
import re
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from torch import Tensor
|
|
||||||
from typing import List, Callable, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTextSplitter(ABC):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_len: int = 512,
|
|
||||||
chunk_overlap: int = 0,
|
|
||||||
):
|
|
||||||
if max_len <= 0:
|
|
||||||
raise ValueError("max_len must be > 0")
|
|
||||||
if chunk_overlap < 0:
|
|
||||||
raise ValueError("chunk_overlap must be >= 0")
|
|
||||||
|
|
||||||
self.max_len = max_len
|
|
||||||
self.chunk_overlap = chunk_overlap
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def split(self, text: str, **kwargs) -> List[str]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def preprocess(self, text: str) -> str:
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
def postprocess(self, chunks: List[str]) -> List[str]:
|
|
||||||
return [chunk.strip() for chunk in chunks if chunk.strip()]
|
|
||||||
|
|
||||||
|
|
||||||
class PriorityTextSplitter(BaseTextSplitter):
    """Split text on the first separator, in priority order, that
    actually divides it into more than one non-empty part.

    Falls back to returning the whole (preprocessed) text as a single
    chunk when no separator produces a split.
    """

    def __init__(
        self,
        separators: List[str],
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        # An empty separator list would make split() a silent no-op.
        if not separators:
            raise ValueError("separators must be a non-empty list")
        self.separators = separators

    def split(self, text: str) -> List[str]:
        cleaned = self.preprocess(text)
        for separator in self.separators:
            pieces = [piece.strip() for piece in cleaned.split(separator) if piece.strip()]
            # First separator yielding a real split wins.
            if len(pieces) > 1:
                return self.postprocess(pieces)
        # No separator matched: single chunk.
        return [cleaned]
|
|
||||||
|
|
||||||
|
|
||||||
class SemanticTextSplitter(BaseTextSplitter):
    """Split text into chunks at semantic boundaries.

    Text is first cut into sentences with a regex, each sentence is
    embedded, and a new chunk is started wherever the cosine similarity
    between the sentence windows before and after a boundary drops below
    a dynamically adjusted threshold.
    """

    # Split after sentence-final punctuation (CJK or ASCII), but only when
    # the position is not inside a quoted span — the lookahead requires an
    # even number of quote characters in the remaining text.
    DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'

    def __init__(
        self,
        embedding_func: Callable[[List[str]], List[Tensor]],
        pattern: Optional[str] = None,
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        """Create a semantic splitter.

        Args:
            embedding_func: Maps a list of sentences to a list of
                embedding tensors, one per sentence.
            pattern: Sentence-boundary regex; defaults to
                ``DEFAULT_PATTERN``.
            max_len: Inherited chunk-length setting.
                NOTE(review): max_len is not enforced anywhere in this
                splitter — chunks may exceed it; confirm intended.
            chunk_overlap: Number of trailing sentences carried over into
                the next chunk at each boundary.
        """
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        if not callable(embedding_func):
            raise TypeError("embedding_func must be callable")
        self.embedding_func = embedding_func
        self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN

    def split(
        self,
        text: str,
        threshold: float = 0.5,
        window_size: int = 1,
    ) -> List[str]:
        """Split *text* into semantically coherent chunks.

        Args:
            text: Input text.
            threshold: Base cosine-similarity threshold below which a
                chunk boundary is inserted.
            window_size: Number of sentences per side used to form the
                comparison windows around each candidate boundary.

        Returns:
            List of chunk strings (stripped, non-empty).

        Raises:
            RuntimeError: If ``embedding_func`` raises.
            ValueError: If ``embedding_func`` does not return exactly one
                vector per sentence.
        """
        text = self.preprocess(text)
        sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]

        # Zero or one sentence: nothing to group.
        if len(sentences) <= 1:
            return self.postprocess(sentences)

        try:
            sentence_embs = self.embedding_func(sentences)
        except Exception as e:
            raise RuntimeError(f"Embedding generation failed: {e}")

        if len(sentence_embs) != len(sentences):
            raise ValueError("Embedding function must return one vector per sentence")

        chunks = []
        emb_tensor = torch.stack(sentence_embs) # shape: [N, D]
        current_chunk: List[str] = [sentences[0]]

        for i in range(1, len(sentences)):
            # Windows of up to `window_size` sentences on each side of
            # the candidate boundary at position i.
            start_prev = max(0, i - window_size)
            end_prev = i
            start_next = i
            end_next = min(len(sentences), i + window_size)

            # Mean-pool each window into a single vector.
            prev_window_emb = emb_tensor[start_prev:end_prev].mean(dim=0)
            next_window_emb = emb_tensor[start_next:end_next].mean(dim=0)

            similarity = F.cosine_similarity(
                prev_window_emb.unsqueeze(0),
                next_window_emb.unsqueeze(0),
                dim=1
            ).item()

            # Wider effective windows get a lower bar (floored at 0.2),
            # so long contexts are not over-split.
            dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)

            if similarity < dynamic_threshold:
                # Boundary: flush the current chunk, keeping the last
                # `chunk_overlap` sentences as overlap for the next one.
                chunks.append(" ".join(current_chunk))
                overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
                current_chunk = current_chunk[overlap_start:]
                current_chunk.append(sentences[i])
            else:
                current_chunk.append(sentences[i])

        # Flush the trailing chunk.
        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return self.postprocess(chunks)
|
|
||||||
|
|
@ -54,7 +54,7 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
def test_model_parameter(test_env):
    """ModelParameter.save writes both the model weights and tokenizer files."""
    save_dir = os.path.join(test_env["test_dir"], "save")
    model_param = ModelParameter(
        test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
    )
    # Call save as a normal instance method rather than through the class
    # with an explicit self (ModelParameter.save(model_param, ...)) — same
    # behavior, idiomatic form.
    model_param.save(save_dir)

    assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
    assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
||||||
|
|
|
||||||
|
|
@ -1,45 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
import json
|
import json
|
||||||
import torch
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from khaosz import Khaosz
|
from khaosz.config.param_config import ModelParameter
|
||||||
from typing import List
|
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
||||||
from tqdm import tqdm
|
from khaosz.inference.core import disable_random_init
|
||||||
|
|
||||||
|
|
||||||
def batch_generate(
    model: Khaosz,
    query: List[str],
    temperature: float,
    top_k: int,
    top_p: float,
    batch_size: int,
) -> List:
    """Generate a response for every query, batching by length.

    Queries are processed longest-first so each batch contains prompts of
    similar length (less padding waste); the returned list is in the
    original input order.

    Args:
        model: Model exposing ``batch_generate(query, temperature, top_k, top_p)``.
        query: Prompts to answer.
        temperature: Sampling temperature forwarded to the model.
        top_k: Top-k sampling parameter forwarded to the model.
        top_p: Nucleus sampling parameter forwarded to the model.
        batch_size: Number of prompts per model call; must be positive.

    Returns:
        Responses aligned with ``query`` (responses[i] answers query[i]).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # Validate with an exception rather than `assert` (asserts vanish under -O).
    if batch_size <= 0:
        raise ValueError("batch_size must be > 0")

    # Sort *indices* by query length (descending). The previous version
    # built a {query_text: index} dict, which collapsed duplicate queries
    # (dict comprehension keeps only the last index), leaving None holes
    # in the result; index-based bookkeeping keeps duplicates distinct.
    order = sorted(range(len(query)), key=lambda idx: len(query[idx]), reverse=True)

    responses: List = [None] * len(query)

    for start in tqdm(range(0, len(order), batch_size), desc="Generating responses"):
        batch_indices = order[start:start + batch_size]
        batch_query = [query[idx] for idx in batch_indices]

        batch_responses = model.batch_generate(
            query=batch_query,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

        # Scatter results back into the original input order.
        for idx, response in zip(batch_indices, batch_responses):
            responses[idx] = response

    return responses
|
|
||||||
|
|
||||||
|
|
||||||
def processor(
|
def processor(
|
||||||
|
|
@ -53,24 +18,31 @@ def processor(
|
||||||
question_key: str,
|
question_key: str,
|
||||||
response_key: str,
|
response_key: str,
|
||||||
):
|
):
|
||||||
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
|
with disable_random_init():
|
||||||
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
|
param.to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
generator = BatchGenerator(param)
|
||||||
|
|
||||||
with open(input_json_file, "r", encoding='utf-8') as f:
|
with open(input_json_file, "r", encoding='utf-8') as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
query = [item[question_key] for item in input_data]
|
|
||||||
|
queries = [item[question_key] for item in input_data]
|
||||||
responses = batch_generate(
|
|
||||||
model=model,
|
request = GenerationRequest(
|
||||||
query=query,
|
query=queries,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
batch_size=batch_size
|
top_k=top_k,
|
||||||
|
max_len=param.config.max_len,
|
||||||
|
history=None,
|
||||||
|
system_prompt=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Write output in JSONL format
|
responses = generator.generate(request)
|
||||||
|
|
||||||
with open(output_json_file, "w", encoding='utf-8') as f:
|
with open(output_json_file, "w", encoding='utf-8') as f:
|
||||||
for query, response in zip(query, responses):
|
for query, response in zip(queries, responses):
|
||||||
output_item = {question_key: query, response_key: response}
|
output_item = {question_key: query, response_key: response}
|
||||||
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
|
f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
|
||||||
|
|
||||||
|
|
@ -89,4 +61,6 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
processor(**vars(args))
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
processor(**vars(args))
|
||||||
|
|
@ -6,7 +6,9 @@ import argparse
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from khaosz import Khaosz
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
from khaosz.inference.core import disable_random_init
|
||||||
|
|
||||||
|
|
||||||
def compute_perplexity(
|
def compute_perplexity(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
|
@ -45,22 +47,23 @@ def process_file(
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
text_key: str
|
text_key: str
|
||||||
):
|
):
|
||||||
model = Khaosz(model_dir).to(device="cuda", dtype=torch.bfloat16)
|
with disable_random_init():
|
||||||
tokenizer = model.parameter.tokenizer
|
param = ModelParameter.load(model_dir)
|
||||||
|
|
||||||
|
param.to(device='cuda', dtype=torch.bfloat16)
|
||||||
|
model = param.model
|
||||||
|
tokenizer = param.tokenizer
|
||||||
|
|
||||||
with open(input_file, "r", encoding='utf-8') as f:
|
with open(input_file, "r", encoding='utf-8') as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
||||||
texts = [item[text_key] for item in input_data]
|
texts = [item[text_key] for item in input_data]
|
||||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||||
|
|
||||||
output_data = []
|
output_data = []
|
||||||
|
|
||||||
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
|
for i in tqdm(range(0, len(encoded_texts), batch_size), desc="Computing perplexity"):
|
||||||
batch_encoded = encoded_texts[i:i + batch_size]
|
batch_encoded = encoded_texts[i:i + batch_size]
|
||||||
batch_texts = texts[i:i + batch_size]
|
batch_texts = texts[i:i + batch_size]
|
||||||
|
|
||||||
# Pad sequences to the same length (left padding)
|
|
||||||
max_len = max(len(seq) for seq in batch_encoded)
|
max_len = max(len(seq) for seq in batch_encoded)
|
||||||
padded_ids = []
|
padded_ids = []
|
||||||
masks = []
|
masks = []
|
||||||
|
|
@ -74,10 +77,7 @@ def process_file(
|
||||||
|
|
||||||
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
||||||
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
||||||
|
perplexity = compute_perplexity(model, input_ids, input_mask)
|
||||||
# Compute perplexity
|
|
||||||
with torch.inference_mode():
|
|
||||||
perplexity = compute_perplexity(model.parameter.model, input_ids, input_mask)
|
|
||||||
|
|
||||||
for text, ppl in zip(batch_texts, perplexity):
|
for text, ppl in zip(batch_texts, perplexity):
|
||||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||||
|
|
@ -87,16 +87,14 @@ def process_file(
|
||||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||||
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
|
||||||
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
|
parser.add_argument("--input_file", type=str, required=True, help="Path to the input file.")
|
||||||
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
|
parser.add_argument("--output_file", type=str, required=True, help="Path to the output file.")
|
||||||
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
|
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
|
||||||
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
|
parser.add_argument("--text_key", type=str, default="text", help="Key for the text field in the input data.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
process_file(**vars(args))
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
with torch.inference_mode():
|
||||||
main()
|
process_file(**vars(args))
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||||
|
|
||||||
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
parser.add_argument("--train_type", type=str, required=True, choices=["seq", "sft", "dpo"], help="Train type.")
|
||||||
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
||||||
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
||||||
|
|
||||||
|
|
@ -67,18 +67,14 @@ def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler
|
||||||
return SchedulerFactory.load(optimizer, **kwargs)
|
return SchedulerFactory.load(optimizer, **kwargs)
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module) -> dict:
    """Return the state dict to checkpoint, unwrapping DDP if present.

    Args:
        model: The model being trained, possibly wrapped in
            ``DistributedDataParallel``.

    Returns:
        The state dict of the underlying model.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        # DDP wraps the real model in .module; save the unwrapped weights
        # so the checkpoint loads into a plain model.
        return model.module.state_dict()
    # Fallback for non-DDP models — without this branch a plain nn.Module
    # would fall through and return None.
    return model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_type: str,
|
train_type: str,
|
||||||
param_path: str,
|
param_path: str,
|
||||||
data_root_path: str,
|
data_root_path: str,
|
||||||
max_lr: int,
|
max_lr: float,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
start_epoch: int,
|
start_epoch: int,
|
||||||
|
|
@ -104,8 +100,7 @@ def train(
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
parameter = ModelParameter()
|
parameter = ModelParameter.load(param_path)
|
||||||
parameter.load(param_path)
|
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = parameter.config.max_len
|
window_size = parameter.config.max_len
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue