51 lines
1.3 KiB
Python
51 lines
1.3 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
from astrai.inference import InferenceEngine
|
|
from astrai.model import AutoModel
|
|
from astrai.tokenize import AutoTokenizer
|
|
|
|
# Repository root, resolved as two directories above this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the pretrained model and tokenizer parameters.
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
|
|
|
|
|
def chat(device="cuda", dtype=torch.bfloat16):
    """Run an interactive streaming chat loop against the local model.

    Loads the model and tokenizer from ``PARAMETER_ROOT``, then repeatedly
    reads a query from stdin, appends it to the conversation history,
    and streams the generated reply to stdout. Type ``!exit`` (or close
    stdin) to quit.

    Args:
        device: Torch device to place the model on. Defaults to "cuda",
            matching the previously hard-coded behavior.
        dtype: Torch dtype for the model weights. Defaults to bfloat16.
    """
    model = AutoModel.from_pretrained(PARAMETER_ROOT)
    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
    model.to(device=device, dtype=dtype)

    messages = []
    engine = InferenceEngine(model=model, tokenizer=tokenizer)

    while True:
        try:
            query = input(">> ")
        except EOFError:
            # stdin closed (e.g. piped input exhausted): exit cleanly
            # instead of crashing with an unhandled exception.
            break
        if query == "!exit":
            break

        # Add user message to the running conversation history.
        messages.append({"role": "user", "content": query})

        # Render the full history into a single prompt string.
        prompt = tokenizer.apply_chat_template(messages, tokenize=False)

        # Stream tokens to stdout while collecting them in a list;
        # joining once afterwards avoids quadratic `+=` concatenation.
        pieces = []
        for token in engine.generate(
            prompt=prompt,
            stream=True,
            max_tokens=2048,
            temperature=0.8,
            top_p=0.95,
            top_k=50,
        ):
            print(token, end="", flush=True)
            pieces.append(token)

        print()
        # Record the assistant reply so later turns keep full context.
        messages.append({"role": "assistant", "content": "".join(pieces).strip()})
|
|
|
|
|
|
# Script entry point: launch the interactive chat REPL when run directly.
if __name__ == "__main__":
    chat()
|