112 lines
3.2 KiB
Python
112 lines
3.2 KiB
Python
from typing import Generator, List, Optional, Tuple, Union

from torch import Tensor

from khaosz.core.generator import (
    BatchGenerator,
    ChatGenerator,
    EmbeddingEncoder,
    RetrievalGenerator,
    StreamGenerator,
    TextGenerator,
)
from khaosz.core.parameter import ParameterLoader


class Khaosz:
    """High-level facade over the khaosz generators and embedding encoder.

    Model parameters are loaded once in ``__init__``.
    Each public method builds the matching helper from ``khaosz.core.generator``
    around those shared parameters and delegates the call to it.
    Sampling arguments are forwarded unchanged.
    """

    def __init__(self, model_dir: str):
        """Load model parameters from ``model_dir``.

        Args:
            model_dir: Directory handed to ``ParameterLoader.load``.
        """
        self.parameter = ParameterLoader.load(model_dir)

    def to(self, *args, **kwargs) -> "Khaosz":
        """Forward ``*args``/``**kwargs`` to ``self.parameter.to`` and return ``self``.

        The value returned by ``self.parameter.to`` is discarded.
        This is only correct if that call converts in place, presumably with
        ``torch.nn.Module.to`` semantics (TODO: confirm for ``ParameterLoader``'s
        return type).
        Returning ``self`` allows chaining, e.g. ``Khaosz(path).to("cuda")``.
        """
        self.parameter.to(*args, **kwargs)
        return self

    def generate(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> str:
        """Produce a chat reply to ``query`` via ``ChatGenerator``.

        Args:
            query: Input text.
            history: Earlier turns as ``(str, str)`` pairs, presumably
                (user, assistant); ``None`` is passed through as-is.
            temperature: Sampling parameter, forwarded unchanged.
            top_k: Sampling parameter, forwarded unchanged.
            top_p: Sampling parameter, forwarded unchanged.

        Returns:
            Whatever ``ChatGenerator.generate`` returns (annotated as ``str``).
        """
        generator = ChatGenerator(self.parameter)
        return generator.generate(
            query,
            history=history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def batch_generate(
        self,
        queries: List[str],
        # NOTE(review): annotated as one history per query; the original hint
        # described a single history. Confirm against BatchGenerator.generate.
        histories: Optional[List[List[Tuple[str, str]]]] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> List[str]:
        """Produce replies for several queries via ``BatchGenerator``.

        Args:
            queries: Input texts.
            histories: Per-query conversation histories, or ``None``; passed
                through to ``BatchGenerator.generate`` unchanged.
            temperature: Sampling parameter, forwarded unchanged.
            top_k: Sampling parameter, forwarded unchanged.
            top_p: Sampling parameter, forwarded unchanged.

        Returns:
            Whatever ``BatchGenerator.generate`` returns (annotated as
            ``List[str]``, presumably one reply per query).
        """
        generator = BatchGenerator(self.parameter)
        return generator.generate(
            queries,
            histories=histories,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def stream_generate(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
        """Stream a chat reply to ``query`` via ``StreamGenerator``.

        Args:
            query: Input text.
            history: Earlier turns as ``(str, str)`` pairs, or ``None``.
            temperature: Sampling parameter, forwarded unchanged.
            top_k: Sampling parameter, forwarded unchanged.
            top_p: Sampling parameter, forwarded unchanged.

        Returns:
            The object returned by ``StreamGenerator.generate``.
            It is annotated as a generator of ``(text, history)`` tuples.
            Its exact streaming semantics (partial vs. cumulative text) are
            defined in ``StreamGenerator``.
        """
        stream_generator = StreamGenerator(self.parameter)
        return stream_generator.generate(
            query,
            history=history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def retrieve_generate(
        self,
        retrieved,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> str:
        """Produce a reply conditioned on retrieved context via ``RetrievalGenerator``.

        Args:
            retrieved: Retrieval result, passed as the first positional
                argument to ``RetrievalGenerator.generate``. Its expected type
                is not visible here (TODO: confirm).
            query: Input text.
            history: Earlier turns as ``(str, str)`` pairs, or ``None``.
            temperature: Sampling parameter, forwarded unchanged.
            top_k: Sampling parameter, forwarded unchanged.
            top_p: Sampling parameter, forwarded unchanged.

        Returns:
            Whatever ``RetrievalGenerator.generate`` returns (annotated as ``str``).
        """
        generator = RetrievalGenerator(self.parameter)
        return generator.generate(
            retrieved,
            query,
            history=history,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def text_generate(
        self,
        query: str,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> str:
        """Continue ``query`` as plain text via ``TextGenerator`` (no chat history).

        Args:
            query: Input text.
            temperature: Sampling parameter, forwarded unchanged.
            top_k: Sampling parameter, forwarded unchanged.
            top_p: Sampling parameter, forwarded unchanged.

        Returns:
            Whatever ``TextGenerator.generate`` returns (annotated as ``str``).
        """
        generator = TextGenerator(self.parameter)
        return generator.generate(
            query,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
        """Embed ``sentence`` via ``EmbeddingEncoder``.

        Args:
            sentence: A single string or a list of strings.

        Returns:
            Whatever ``EmbeddingEncoder.encode`` returns: a ``Tensor`` or a list
            of ``Tensor`` per the annotation (presumably a list for list input).
        """
        encoder = EmbeddingEncoder(self.parameter)
        return encoder.encode(sentence)