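"""Minimal interactive text-generation script.

Loads pretrained weights from the params/ directory two levels above this
file, reads a single prompt from stdin, and prints one sampled completion.
Assumes a CUDA-capable GPU and a tokenizer/ subdirectory under params/.
"""
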
from pathlib import Path

import torch

from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer

PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")


def generate_text():
    # Load the pretrained model and move it to the GPU in bfloat16.
    model = AutoModel.from_pretrained(PARAMETER_ROOT)
    model.to(device="cuda", dtype=torch.bfloat16)

    # Load the matching tokenizer from the same parameter directory.
    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT / "tokenizer")

    # Read a single prompt from stdin.
    query = input(">> ")

    # Wrap the model and tokenizer in the inference engine and sample a reply.
    engine = InferenceEngine(
        model=model,
        tokenizer=tokenizer,
    )
    response = engine.generate(
        prompt=query,
        stream=False,
        max_tokens=2048,
        temperature=0.8,
        top_p=0.95,
        top_k=50,
    )

    print(response)


if __name__ == "__main__":
    generate_text()