chore: 更新脚本并且修改gitignore

This commit is contained in:
ViperEkura 2026-04-02 15:40:31 +08:00
parent 475de51c7d
commit 912d7c7f54
6 changed files with 17 additions and 12 deletions

18
.gitignore vendored
View File

@ -6,11 +6,15 @@
# Allow specific file types and root files
!*.py
!*.md
!*.png
!LICENSE
!pyproject.toml
!.github/ISSUE_TEMPLATE/*
!.github/workflows/lint.yml
!.github/workflows/tests.yml
# Allow GitHub files
!/.github/ISSUE_TEMPLATE/*
!/.github/workflows/lint.yml
!/.github/workflows/tests.yml
# Allow root files
!/assets/*
!/CONTRIBUTING.md
!/LICENSE
!/pyproject.toml
!/README.md

View File

@ -81,7 +81,7 @@ class BaseModelIO:
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if self.model is None:
if isinstance(self.model, nn.Identity):
with disable_random_init(enable=disable_init):
self.model = Transformer(self.config)

View File

@ -1,7 +1,7 @@
from pathlib import Path
from huggingface_hub import snapshot_download
PROJECT_ROOT = Path(__file__).parent.parent
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
if __name__ == "__main__":

View File

@ -3,7 +3,7 @@ from pathlib import Path
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
PROJECT_ROOT = Path(__file__).parent.parent
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -3,7 +3,7 @@ from pathlib import Path
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
PROJECT_ROOT = Path(__file__).parent.parent
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@ -3,7 +3,7 @@ from pathlib import Path
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import GeneratorFactory, GenerationRequest
PROJECT_ROOT = Path(__file__).parent.parent
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@ -25,6 +25,7 @@ def chat():
max_len=param.config.max_len,
history=history,
system_prompt=None,
stream=True,
)
generator = GeneratorFactory.create(param, request)