# AstrAI/scripts/demo/stream_chat.py
"""Interactive streaming chat demo for AstrAI."""
from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference import InferenceEngine
# Repository root: two levels above this script (scripts/demo/ -> repo root).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the saved model parameters loaded by chat().
PARAMETER_ROOT = PROJECT_ROOT / "params"
def chat():
    """Run an interactive REPL that streams model completions to stdout.

    Loads model parameters from PARAMETER_ROOT, moves them to CUDA in
    bfloat16, then reads user prompts in a loop until the user types
    "!exit". Each completion is streamed token-by-token to the terminal.
    """
    params = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
    params.to(device="cuda", dtype=torch.bfloat16)
    engine = InferenceEngine(params)
    # NOTE(review): (query, response) pairs are accumulated here but never
    # passed back into engine.generate(), so turns are independent — confirm
    # whether multi-turn context was intended.
    history = []
    while (user_text := input(">> ")) != "!exit":
        pieces = []
        for token in engine.generate(
            prompt=user_text,
            stream=True,
            max_tokens=params.config.max_len,
            temperature=0.8,
            top_p=0.95,
            top_k=50,
        ):
            # Stream each token as it arrives, without a trailing newline.
            print(token, end="", flush=True)
            pieces.append(token)
        print()
        history.append((user_text, "".join(pieces).strip()))
# Script entry point: launch the interactive chat loop.
if __name__ == "__main__":
    chat()