# AstrAI/scripts/demo/generate_ar.py
# Demo: autoregressive text generation from locally stored model parameters.
from pathlib import Path
import torch
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
# Repository root: two directories above this script (scripts/demo/ -> root).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the pretrained model parameters and tokenizer files.
PARAMETER_ROOT = PROJECT_ROOT / "params"
def generate_text():
    """Load the pretrained model and tokenizer, read a prompt from stdin,
    and print one sampled completion.

    Side effects: reads a line from stdin, prints the generated text.
    """
    # Load the model weights from the local parameter directory.
    model = AutoModel.from_pretrained(PARAMETER_ROOT)
    # NOTE(review): device and dtype are hard-coded — this fails on
    # CPU-only hosts; consider falling back when CUDA is unavailable.
    model.to(device="cuda", dtype=torch.bfloat16)

    # Tokenizer files live in a "tokenizer" subdirectory of the params root.
    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT / "tokenizer")

    query = input(">> ")

    engine = InferenceEngine(
        model=model,
        tokenizer=tokenizer,
    )
    # Single non-streaming generation with top-p/top-k sampling.
    response = engine.generate(
        prompt=query,
        stream=False,
        max_tokens=2048,
        temperature=0.8,
        top_p=0.95,
        top_k=50,
    )
    print(response)
if __name__ == "__main__":
generate_text()