chore: 更新脚本并且修改gitignore

This commit is contained in:
ViperEkura 2026-04-02 15:40:31 +08:00
parent 475de51c7d
commit 912d7c7f54
6 changed files with 17 additions and 12 deletions

18
.gitignore vendored
View File

@ -6,11 +6,15 @@
# Allow specific file types and root files # Allow specific file types and root files
!*.py !*.py
!*.md
!*.png
!LICENSE # Allow GitHub files
!pyproject.toml !/.github/ISSUE_TEMPLATE/*
!.github/ISSUE_TEMPLATE/* !/.github/workflows/lint.yml
!.github/workflows/lint.yml !/.github/workflows/tests.yml
!.github/workflows/tests.yml
# Allow root files
!/assets/*
!/CONTRIBUTING.md
!/LICENSE
!/pyproject.toml
!/README.md

View File

@ -81,7 +81,7 @@ class BaseModelIO:
self.config.load(str(paths["config"])) self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"])) self.tokenizer.load(str(paths["tokenizer"]))
if self.model is None: if isinstance(self.model, nn.Identity):
with disable_random_init(enable=disable_init): with disable_random_init(enable=disable_init):
self.model = Transformer(self.config) self.model = Transformer(self.config)

View File

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params") PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -3,7 +3,7 @@ from pathlib import Path
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest from astrai.inference.generator import GeneratorFactory, GenerationRequest
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params") PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -3,7 +3,7 @@ from pathlib import Path
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest from astrai.inference.generator import GeneratorFactory, GenerationRequest
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params") PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -3,7 +3,7 @@ from pathlib import Path
from astrai.config.param_config import ModelParameter from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest from astrai.inference.generator import GeneratorFactory, GenerationRequest
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params") PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@ -25,6 +25,7 @@ def chat():
max_len=param.config.max_len, max_len=param.config.max_len,
history=history, history=history,
system_prompt=None, system_prompt=None,
stream=True,
) )
generator = GeneratorFactory.create(param, request) generator = GeneratorFactory.create(param, request)