"""Streaming chat demo (AstrAI/scripts/demo/stream_chat.py)."""
from pathlib import Path
import torch
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
# Repository root: two directories above this script (scripts/demo/ -> root).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the pretrained model weights and tokenizer files.
PARAMETER_ROOT = PROJECT_ROOT / "params"
def chat():
    """Run an interactive streaming chat session against the local model.

    Loads the model and tokenizer from ``PARAMETER_ROOT``, moves the model
    to CUDA in bfloat16, then reads user turns from stdin until ``!exit``
    is entered (or stdin closes / Ctrl-C).  Each user turn and each
    assistant reply is appended to the running message history, so every
    generation sees the full conversation so far.  Tokens are streamed to
    stdout as they are produced.
    """
    model = AutoModel.from_pretrained(PARAMETER_ROOT)
    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
    # NOTE(review): hard-codes a CUDA device — confirm a GPU is required
    # for this demo, otherwise it will raise on CPU-only machines.
    model.to(device="cuda", dtype=torch.bfloat16)
    messages = []
    engine = InferenceEngine(model=model, tokenizer=tokenizer)
    while True:
        # Robustness fix: Ctrl-D (EOFError) and Ctrl-C (KeyboardInterrupt)
        # previously crashed the demo with a traceback; exit cleanly instead.
        try:
            query = input(">> ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if query == "!exit":
            break
        # Add the user turn to the conversation history.
        messages.append({"role": "user", "content": query})
        # Render the whole history into a single prompt string
        # (tokenize=False: the engine tokenizes internally).
        prompt = tokenizer.apply_chat_template(messages, tokenize=False)
        full_response = ""
        # Stream tokens to stdout as they are generated.
        for token in engine.generate(
            prompt=prompt,
            stream=True,
            max_tokens=model.config.max_len,
            temperature=0.8,
            top_p=0.95,
            top_k=50,
        ):
            print(token, end="", flush=True)
            full_response += token
        print()
        # Record the assistant turn so subsequent queries keep the context.
        messages.append({"role": "assistant", "content": full_response.strip()})
# Run the interactive demo only when executed as a script, not on import.
if __name__ == "__main__":
    chat()