diff --git a/scripts/generate_ar.py b/scripts/generate_ar.py
index 23b2832..77c216c 100644
--- a/scripts/generate_ar.py
+++ b/scripts/generate_ar.py
@@ -14,9 +14,9 @@ def generate_text():
 
     response = model.text_generate(
         query=query,
-        temperature=0.6,
+        temperature=0.8,
         top_p=0.95,
-        top_k=30
+        top_k=50
     )
 
     print(response)
diff --git a/scripts/generate_batch.py b/scripts/generate_batch.py
index 037b04f..43ece47 100644
--- a/scripts/generate_batch.py
+++ b/scripts/generate_batch.py
@@ -13,9 +13,9 @@ def batch_generate():
 
     responses = model.batch_generate(
         queries=inputs,
-        temperature=0.7,
+        temperature=0.8,
         top_p=0.95,
-        top_k=30
+        top_k=50
     )
 
     for q, r in zip(inputs, responses):
diff --git a/scripts/generate_retrieve.py b/scripts/generate_retrieve.py
index 46a9e56..6ba4197 100644
--- a/scripts/generate_retrieve.py
+++ b/scripts/generate_retrieve.py
@@ -30,9 +30,9 @@ if __name__ == "__main__":
 
     retrive_response = model.retrieve_generate(
         retrieved=retrieved,
         query=query,
-        temperature=0.7,
-        top_k=30,
+        temperature=0.8,
         top_p=0.95,
+        top_k=50
     )
     print("retrive content:")
diff --git a/scripts/stream_chat.py b/scripts/stream_chat.py
index 399717b..112720d 100644
--- a/scripts/stream_chat.py
+++ b/scripts/stream_chat.py
@@ -20,9 +20,9 @@ def chat():
     for response, histroy in model.stream_generate(
         query=query,
         history=histroy,
-        temperature=0.7,
+        temperature=0.8,
         top_p=0.95,
-        top_k=30
+        top_k=50
     ):
         print(response[response_size:], end="", flush=True)
         response_size = len(response)