chore: 更新脚本并且修改gitignore
This commit is contained in:
parent
475de51c7d
commit
912d7c7f54
|
|
@ -6,11 +6,15 @@
|
||||||
|
|
||||||
# Allow specific file types and root files
|
# Allow specific file types and root files
|
||||||
!*.py
|
!*.py
|
||||||
!*.md
|
|
||||||
!*.png
|
|
||||||
|
|
||||||
!LICENSE
|
# Allow GitHub files
|
||||||
!pyproject.toml
|
!/.github/ISSUE_TEMPLATE/*
|
||||||
!.github/ISSUE_TEMPLATE/*
|
!/.github/workflows/lint.yml
|
||||||
!.github/workflows/lint.yml
|
!/.github/workflows/tests.yml
|
||||||
!.github/workflows/tests.yml
|
|
||||||
|
# Allow root files
|
||||||
|
!/assets/*
|
||||||
|
!/CONTRIBUTING.md
|
||||||
|
!/LICENSE
|
||||||
|
!/pyproject.toml
|
||||||
|
!/README.md
|
||||||
|
|
@ -81,7 +81,7 @@ class BaseModelIO:
|
||||||
self.config.load(str(paths["config"]))
|
self.config.load(str(paths["config"]))
|
||||||
self.tokenizer.load(str(paths["tokenizer"]))
|
self.tokenizer.load(str(paths["tokenizer"]))
|
||||||
|
|
||||||
if self.model is None:
|
if isinstance(self.model, nn.Identity):
|
||||||
with disable_random_init(enable=disable_init):
|
with disable_random_init(enable=disable_init):
|
||||||
self.model = Transformer(self.config)
|
self.model = Transformer(self.config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -25,6 +25,7 @@ def chat():
|
||||||
max_len=param.config.max_len,
|
max_len=param.config.max_len,
|
||||||
history=history,
|
history=history,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
generator = GeneratorFactory.create(param, request)
|
generator = GeneratorFactory.create(param, request)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue