diff --git a/astrai/config/param_config.py b/astrai/config/param_config.py
index 1a3f79d..8d6f33d 100644
--- a/astrai/config/param_config.py
+++ b/astrai/config/param_config.py
@@ -1,6 +1,7 @@
 import torch.nn as nn
 import safetensors.torch as st
 
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import Optional, Self, Union
 from pathlib import Path
@@ -10,12 +11,38 @@ from astrai.config.model_config import ModelConfig
 from astrai.model.transformer import Transformer
 
 
+@contextmanager
+def disable_random_init(enable: bool = True):
+    init_functions = [
+        "xavier_normal_",
+        "xavier_uniform_",
+        "kaiming_normal_",
+        "kaiming_uniform_",
+        "zeros_",
+        "ones_",
+        "constant_",
+        "normal_",
+        "uniform_",
+    ]
+    original_funcs = {}
+    for name in init_functions:
+        if enable and hasattr(nn.init, name):
+            original_funcs[name] = getattr(nn.init, name)
+            setattr(nn.init, name, lambda *args, **kwargs: None)
+    try:
+        yield
+    finally:
+        if enable:
+            for name, orig_func in original_funcs.items():
+                setattr(nn.init, name, orig_func)
+
+
 @dataclass
 class BaseModelIO:
     """Base class for model I/O operations."""
 
-    model: Optional[nn.Module] = field(
-        default=None, metadata={"help": "Transformer model."}
+    model: nn.Module = field(
+        default_factory=nn.Identity, metadata={"help": "Transformer model."}
     )
     tokenizer: BpeTokenizer = field(
         default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
@@ -41,10 +68,13 @@ class BaseModelIO:
 
         if self.model is not None:
             st.save_file(self.model.state_dict(), str(paths["model"]))
 
+        self.config.save(str(paths["config"]))
         self.tokenizer.save(str(paths["tokenizer"]))
 
-    def load_components(self, load_dir: Union[str, Path]) -> Self:
+    def load_components(
+        self, load_dir: Union[str, Path], disable_init: bool = False
+    ) -> Self:
         """Load core model components."""
 
         paths = self._get_file_paths(load_dir)
@@ -52,7 +82,8 @@ class BaseModelIO:
         self.tokenizer.load(str(paths["tokenizer"]))
 
         if self.model is None:
-            self.model = Transformer(self.config)
+            with disable_random_init(enable=disable_init):
+                self.model = Transformer(self.config)
 
         if paths["model"].exists():
             state_dict = st.load_file(str(paths["model"]))
@@ -76,6 +107,8 @@ class ModelParameter(BaseModelIO):
         instance.save_components(save_dir)
 
     @classmethod
-    def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
+    def load(
+        cls, load_dir: Union[str, Path], disable_init: bool = False
+    ) -> "ModelParameter":
         instance = cls()
-        return instance.load_components(load_dir)
+        return instance.load_components(load_dir, disable_init=disable_init)
diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py
index 28e2174..a63d77e 100644
--- a/astrai/inference/__init__.py
+++ b/astrai/inference/__init__.py
@@ -1,5 +1,4 @@
 from astrai.inference.core import (
-    disable_random_init,
     GeneratorCore,
     EmbeddingEncoderCore,
     KVCacheManager,
@@ -15,7 +14,6 @@ from astrai.inference.generator import (
 )
 
 __all__ = [
-    "disable_random_init",
     "GeneratorCore",
     "EmbeddingEncoderCore",
     "KVCacheManager",
diff --git a/astrai/inference/core.py b/astrai/inference/core.py
index db3a433..185b935 100644
--- a/astrai/inference/core.py
+++ b/astrai/inference/core.py
@@ -1,8 +1,6 @@
 import torch
-import torch.nn as nn
 
 from torch import Tensor
-from contextlib import contextmanager
 from typing import Any, Callable, List, Tuple, Union, Optional, Self
 
 from astrai.config import ModelParameter, ModelConfig
@@ -55,31 +53,6 @@ def apply_sampling_strategies(
     return logits
 
 
-@contextmanager
-def disable_random_init():
-    init_functions = [
-        "xavier_normal_",
-        "xavier_uniform_",
-        "kaiming_normal_",
-        "kaiming_uniform_",
-        "zeros_",
-        "ones_",
-        "constant_",
-        "normal_",
-        "uniform_",
-    ]
-    original_funcs = {}
-    for name in init_functions:
-        if hasattr(nn.init, name):
-            original_funcs[name] = getattr(nn.init, name)
-            setattr(nn.init, name, lambda *args, **kwargs: None)
-    try:
-        yield
-    finally:
-        for name, orig_func in original_funcs.items():
-            setattr(nn.init, name, orig_func)
-
-
 class GeneratorCore:
     def __init__(self, parameter: ModelParameter):
         self.model = parameter.model
diff --git a/scripts/demo/download.py b/scripts/demo/download.py
index fb3150f..8cb9052 100644
--- a/scripts/demo/download.py
+++ b/scripts/demo/download.py
@@ -6,7 +6,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 if __name__ == "__main__":
     snapshot_download(
-        repo_id="ViperEk/AstrAI",
+        repo_id="ViperEk/KHAOSZ",
         local_dir=PARAMETER_ROOT,
         force_download=True,
     )
diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py
index c75a601..bf9959c 100644
--- a/scripts/demo/generate_ar.py
+++ b/scripts/demo/generate_ar.py
@@ -1,7 +1,6 @@
 import torch
 from pathlib import Path
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 from astrai.inference.generator import GeneratorFactory, GenerationRequest
 
 PROJECT_ROOT = Path(__file__).parent.parent
@@ -9,10 +8,8 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 
 def generate_text():
-
-    with disable_random_init():
-        param = ModelParameter.load(PARAMETER_ROOT)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     query = input(">> ")
diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py
index 4edd88e..fff99f8 100644
--- a/scripts/demo/generate_batch.py
+++ b/scripts/demo/generate_batch.py
@@ -1,7 +1,6 @@
 import torch
 from pathlib import Path
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 from astrai.inference.generator import GeneratorFactory, GenerationRequest
 
 PROJECT_ROOT = Path(__file__).parent.parent
@@ -9,10 +8,8 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 
 def batch_generate():
-
-    with disable_random_init():
-        param = ModelParameter.load(PARAMETER_ROOT)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     inputs = [
         "你好",
diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py
index fbe31ea..f89ce72 100644
--- a/scripts/demo/stream_chat.py
+++ b/scripts/demo/stream_chat.py
@@ -1,7 +1,6 @@
 import torch
 from pathlib import Path
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 from astrai.inference.generator import GeneratorFactory, GenerationRequest
 
 PROJECT_ROOT = Path(__file__).parent.parent
@@ -9,10 +8,8 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 
 def chat():
-
-    with disable_random_init():
-        param = ModelParameter.load(PARAMETER_ROOT)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     history = []
     while True:
diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py
index 2f78b2a..3d10489 100644
--- a/scripts/tools/generate.py
+++ b/scripts/tools/generate.py
@@ -4,23 +4,19 @@ import argparse
 
 from astrai.config.param_config import ModelParameter
 from astrai.inference.generator import BatchGenerator, GenerationRequest
-from astrai.inference.core import disable_random_init
 
 
 def processor(
     model_dir: str,
     input_json_file: str,
     output_json_file: str,
-    batch_size: int,
     temperature: float,
     top_k: int,
     top_p: float,
     question_key: str,
     response_key: str,
 ):
-    with disable_random_init():
-        param = ModelParameter.load(model_dir)
-
+    param = ModelParameter.load(model_dir, disable_init=True)
     param.to(device="cuda", dtype=torch.bfloat16)
 
     generator = BatchGenerator(param)
diff --git a/scripts/tools/perplexity.py b/scripts/tools/perplexity.py
index e6a8de8..3fe02f4 100644
--- a/scripts/tools/perplexity.py
+++ b/scripts/tools/perplexity.py
@@ -7,7 +7,6 @@ import tqdm
 
 from torch import Tensor
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 
 
 def compute_perplexity(
@@ -42,9 +41,7 @@ def compute_perplexity(
 def process_file(
     model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
 ):
-    with disable_random_init():
-        param = ModelParameter.load(model_dir)
-
+    param = ModelParameter.load(model_dir, disable_init=True)
     param.to(device="cuda", dtype=torch.bfloat16)
     model = param.model
     tokenizer = param.tokenizer
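
For reference, a minimal usage sketch of the loading path after this change, assuming a checkpoint directory laid out as in the demo scripts (the "params" path below is illustrative): callers no longer import disable_random_init from astrai.inference.core; skipping random weight initialization is requested through the new disable_init flag on ModelParameter.load, which wraps Transformer construction in the context manager now living in astrai.config.param_config.

from pathlib import Path

import torch

from astrai.config.param_config import ModelParameter

# Illustrative checkpoint location; the demo scripts resolve it as
# Path(PROJECT_ROOT, "params") after running scripts/demo/download.py.
PARAMETER_ROOT = Path("params")

# disable_init=True patches the nn.init.* functions to no-ops only while the
# Transformer is constructed, so the weights come solely from the loaded
# safetensors state dict instead of being randomly initialized first.
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)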