AstrAI/demo/stream_chat.py
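# Minimal streaming chat demo: load the model parameters once, then read user
# queries in a loop, stream each response to the terminal token by token, and
# carry the conversation history across turns.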

import torch
from pathlib import Path

from astrai.config.param_config import ModelParameter
from astrai.inference.core import disable_random_init
from astrai.inference.generator import GeneratorFactory, GenerationRequest

PROJECT_ROOT = Path(__file__).parent.parent
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
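# Saved parameters are expected under PARAMETER_ROOT (<project root>/params);
# the model is loaded onto a CUDA device in bfloat16 before the chat loop starts.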

def chat():
    # Load the saved parameters once, with random weight initialization
    # disabled, and move them to the GPU in bfloat16.
    with disable_random_init():
        param = ModelParameter.load(PARAMETER_ROOT)
        param.to(device="cuda", dtype=torch.bfloat16)

    history = []
    while True:
        query = input(">> ")
        if query == "!exit":
            break

        request = GenerationRequest(
            query=query,
            temperature=0.8,
            top_p=0.95,
            top_k=50,
            max_len=param.config.max_len,
            history=history,
            system_prompt=None,
        )
        generator = GeneratorFactory.create(param, request)

        response_size = 0
        full_response = ""
        for response in generator.generate(request):
            # response is the cumulative text generated so far; print only the
            # newly generated suffix so the output streams token by token.
            print(response[response_size:], end="", flush=True)
            response_size = len(response)
            full_response = response
        print()  # finish the streamed line before showing the next prompt

        # After generation, record the turn so later requests carry the context.
        history.append((query, full_response.strip()))


if __name__ == "__main__":
    chat()