chore: 更新脚本并且修改gitignore
This commit is contained in:
parent
475de51c7d
commit
912d7c7f54
|
|
@ -6,11 +6,15 @@
|
|||
|
||||
# Allow specific file types and root files
|
||||
!*.py
|
||||
!*.md
|
||||
!*.png
|
||||
|
||||
!LICENSE
|
||||
!pyproject.toml
|
||||
!.github/ISSUE_TEMPLATE/*
|
||||
!.github/workflows/lint.yml
|
||||
!.github/workflows/tests.yml
|
||||
# Allow GitHub files
|
||||
!/.github/ISSUE_TEMPLATE/*
|
||||
!/.github/workflows/lint.yml
|
||||
!/.github/workflows/tests.yml
|
||||
|
||||
# Allow root files
|
||||
!/assets/*
|
||||
!/CONTRIBUTING.md
|
||||
!/LICENSE
|
||||
!/pyproject.toml
|
||||
!/README.md
|
||||
|
|
@ -81,7 +81,7 @@ class BaseModelIO:
|
|||
self.config.load(str(paths["config"]))
|
||||
self.tokenizer.load(str(paths["tokenizer"]))
|
||||
|
||||
if self.model is None:
|
||||
if isinstance(self.model, nn.Identity):
|
||||
with disable_random_init(enable=disable_init):
|
||||
self.model = Transformer(self.config)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||
|
||||
|
||||
|
|
@ -25,6 +25,7 @@ def chat():
|
|||
max_len=param.config.max_len,
|
||||
history=history,
|
||||
system_prompt=None,
|
||||
stream=True,
|
||||
)
|
||||
generator = GeneratorFactory.create(param, request)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue