diff --git a/.gitignore b/.gitignore index f8f1d17..def3a3f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,11 +6,15 @@ # Allow specific file types and root files !*.py -!*.md -!*.png -!LICENSE -!pyproject.toml -!.github/ISSUE_TEMPLATE/* -!.github/workflows/lint.yml -!.github/workflows/tests.yml \ No newline at end of file +# Allow GitHub files +!/.github/ISSUE_TEMPLATE/* +!/.github/workflows/lint.yml +!/.github/workflows/tests.yml + +# Allow root files +!/assets/* +!/CONTRIBUTING.md +!/LICENSE +!/pyproject.toml +!/README.md \ No newline at end of file diff --git a/astrai/config/param_config.py b/astrai/config/param_config.py index 8d6f33d..c5d7076 100644 --- a/astrai/config/param_config.py +++ b/astrai/config/param_config.py @@ -81,7 +81,7 @@ class BaseModelIO: self.config.load(str(paths["config"])) self.tokenizer.load(str(paths["tokenizer"])) - if self.model is None: + if isinstance(self.model, nn.Identity): with disable_random_init(enable=disable_init): self.model = Transformer(self.config) diff --git a/scripts/demo/download.py b/scripts/demo/download.py index 8cb9052..ba7aaf3 100644 --- a/scripts/demo/download.py +++ b/scripts/demo/download.py @@ -1,7 +1,7 @@ from pathlib import Path from huggingface_hub import snapshot_download -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") if __name__ == "__main__": diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py index bf9959c..d1dfbaf 100644 --- a/scripts/demo/generate_ar.py +++ b/scripts/demo/generate_ar.py @@ -3,7 +3,7 @@ from pathlib import Path from astrai.config.param_config import ModelParameter from astrai.inference.generator import GeneratorFactory, GenerationRequest -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py index fff99f8..d813341 100644 --- a/scripts/demo/generate_batch.py +++ b/scripts/demo/generate_batch.py @@ -3,7 +3,7 @@ from pathlib import Path from astrai.config.param_config import ModelParameter from astrai.inference.generator import GeneratorFactory, GenerationRequest -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py index f89ce72..937823f 100644 --- a/scripts/demo/stream_chat.py +++ b/scripts/demo/stream_chat.py @@ -3,7 +3,7 @@ from pathlib import Path from astrai.config.param_config import ModelParameter from astrai.inference.generator import GeneratorFactory, GenerationRequest -PROJECT_ROOT = Path(__file__).parent.parent +PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") @@ -25,6 +25,7 @@ def chat(): max_len=param.config.max_len, history=history, system_prompt=None, + stream=True, ) generator = GeneratorFactory.create(param, request)