36 lines
881 B
Python
36 lines
881 B
Python
import torch
|
|
from pathlib import Path
|
|
from khaosz.config.param_config import ModelParameter
|
|
from khaosz.inference.core import disable_random_init
|
|
from khaosz.inference.generator import GeneratorFactory, GenerationRequest
|
|
|
|
# Directory two levels above this file (the parent of the folder holding it).
PROJECT_ROOT = Path(__file__).parent.parent

# Location passed to ModelParameter.load(): the "params" folder under the root.
PARAMETER_ROOT = PROJECT_ROOT / "params"
|
|
|
|
|
|
def generate_text():
    """Answer a single prompt typed on stdin and print the model's response.

    Steps, in order:

    1. Inside ``disable_random_init()``, load the parameters from
       ``PARAMETER_ROOT`` and call ``.to(device="cuda", dtype=torch.bfloat16)``
       on them.
    2. Prompt the user with ``">> "`` and wrap the typed text in a
       ``GenerationRequest`` using fixed sampling settings, no history and
       no system prompt.
    3. Ask ``GeneratorFactory`` for a generator matching the request, run it,
       and print whatever it returns.

    Returns:
        None. The response is only written to stdout.
    """
    # NOTE(review): presumably this context skips random weight initialisation
    # while the checkpoint is being loaded -- confirm in khaosz.inference.core.
    with disable_random_init():
        model_param = ModelParameter.load(PARAMETER_ROOT)
        model_param.to(device="cuda", dtype=torch.bfloat16)

    gen_request = GenerationRequest(
        query=input(">> "),
        temperature=0.8,
        top_p=0.95,
        top_k=50,
        # Use the length limit stored in the loaded model configuration.
        max_len=model_param.config.max_len,
        history=None,
        system_prompt=None,
    )

    text_generator = GeneratorFactory.create(model_param, gen_request)
    print(text_generator.generate(gen_request))
|
|
|
|
|
|
# Run the interactive demo only when executed as a script, not on import.
if __name__ == "__main__":
    generate_text()
|