From 912d7c7f545e410bb54000a288d4ba1b27cd76e7 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 2 Apr 2026 15:40:31 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E6=9B=B4=E6=96=B0=E8=84=9A=E6=9C=AC?= =?UTF-8?q?=E5=B9=B6=E4=B8=94=E4=BF=AE=E6=94=B9gitignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 18 +++++++++++------- astrai/config/param_config.py | 2 +- scripts/demo/download.py | 2 +- scripts/demo/generate_ar.py | 2 +- scripts/demo/generate_batch.py | 2 +- scripts/demo/stream_chat.py | 3 ++- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index f8f1d17..def3a3f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,11 +6,15 @@ # Allow specific file types and root files !*.py -!*.md -!*.png -!LICENSE -!pyproject.toml -!.github/ISSUE_TEMPLATE/* -!.github/workflows/lint.yml -!.github/workflows/tests.yml \ No newline at end of file +# Allow GitHub files +!/.github/ISSUE_TEMPLATE/* +!/.github/workflows/lint.yml +!/.github/workflows/tests.yml + +# Allow root files +!/assets/* +!/CONTRIBUTING.md +!/LICENSE +!/pyproject.toml +!/README.md \ No newline at end of file diff --git a/astrai/config/param_config.py b/astrai/config/param_config.py index 8d6f33d..c5d7076 100644 --- a/astrai/config/param_config.py +++ b/astrai/config/param_config.py @@ -81,7 +81,7 @@ class BaseModelIO: self.config.load(str(paths["config"])) self.tokenizer.load(str(paths["tokenizer"])) - if self.model is None: + if isinstance(self.model, nn.Identity): with disable_random_init(enable=disable_init): self.model = Transformer(self.config) diff --git a/scripts/demo/download.py b/scripts/demo/download.py index 8cb9052..ba7aaf3 100644 --- a/scripts/demo/download.py +++ b/scripts/demo/download.py @@ -1,7 +1,7 @@ from pathlib import Path from huggingface_hub import snapshot_download -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") if __name__ == "__main__": diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py index bf9959c..d1dfbaf 100644 --- a/scripts/demo/generate_ar.py +++ b/scripts/demo/generate_ar.py @@ -3,7 +3,7 @@ from pathlib import Path from astrai.config.param_config import ModelParameter from astrai.inference.generator import GeneratorFactory, GenerationRequest -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py index fff99f8..d813341 100644 --- a/scripts/demo/generate_batch.py +++ b/scripts/demo/generate_batch.py @@ -3,7 +3,7 @@ from pathlib import Path from astrai.config.param_config import ModelParameter from astrai.inference.generator import GeneratorFactory, GenerationRequest -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py index f89ce72..937823f 100644 --- a/scripts/demo/stream_chat.py +++ b/scripts/demo/stream_chat.py @@ -3,7 +3,7 @@ from pathlib import Path from astrai.config.param_config import ModelParameter from astrai.inference.generator import GeneratorFactory, GenerationRequest -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") @@ -25,6 +25,7 @@ def chat(): max_len=param.config.max_len, history=history, system_prompt=None, + stream=True, ) generator = GeneratorFactory.create(param, request)