refactor: modify ModelParameter

ViperEkura 2026-03-31 16:00:55 +08:00
parent 80c0b20877
commit 9f1561afe7
9 changed files with 48 additions and 60 deletions

View File

@@ -1,6 +1,7 @@
 import torch.nn as nn
 import safetensors.torch as st
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import Optional, Self, Union
 from pathlib import Path
@@ -10,12 +11,38 @@ from astrai.config.model_config import ModelConfig
 from astrai.model.transformer import Transformer
 
+
+@contextmanager
+def disable_random_init(enable: bool = True):
+    init_functions = [
+        "xavier_normal_",
+        "xavier_uniform_",
+        "kaiming_normal_",
+        "kaiming_uniform_",
+        "zeros_",
+        "ones_",
+        "constant_",
+        "normal_",
+        "uniform_",
+    ]
+    original_funcs = {}
+    for name in init_functions:
+        if enable and hasattr(nn.init, name):
+            original_funcs[name] = getattr(nn.init, name)
+            setattr(nn.init, name, lambda *args, **kwargs: None)
+    try:
+        yield
+    finally:
+        if enable:
+            for name, orig_func in original_funcs.items():
+                setattr(nn.init, name, orig_func)
+
 @dataclass
 class BaseModelIO:
     """Base class for model I/O operations."""
 
-    model: Optional[nn.Module] = field(
-        default=None, metadata={"help": "Transformer model."}
-    )
+    model: nn.Module = field(
+        default_factory=nn.Identity, metadata={"help": "Transformer model."}
+    )
     tokenizer: BpeTokenizer = field(
         default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
     )
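For context, the helper added above works by temporarily monkey-patching the in-place initializers on nn.init into no-ops while the model is constructed, then restoring the originals; freshly created parameters keep the uninitialized memory torch.empty leaves behind, which is safe when a checkpoint is loaded immediately afterwards. A minimal sketch of the effect, assuming disable_random_init as defined above is in scope (nn.Linear here is only illustrative):

import torch.nn as nn

# Inside the context (enable=True), nn.init.kaiming_uniform_, uniform_, etc.
# are no-ops, so nn.Linear skips its random fill and keeps torch.empty memory.
with disable_random_init(enable=True):
    layer = nn.Linear(1024, 1024)

# Outside the context the original initializers are restored.
check = nn.Linear(8, 8)  # initialized normally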
@@ -41,10 +68,13 @@ class BaseModelIO:
         if self.model is not None:
             st.save_file(self.model.state_dict(), str(paths["model"]))
 
         self.config.save(str(paths["config"]))
         self.tokenizer.save(str(paths["tokenizer"]))
 
-    def load_components(self, load_dir: Union[str, Path]) -> Self:
+    def load_components(
+        self, load_dir: Union[str, Path], disable_init: bool = False
+    ) -> Self:
         """Load core model components."""
         paths = self._get_file_paths(load_dir)
@@ -52,7 +82,8 @@ class BaseModelIO:
         self.tokenizer.load(str(paths["tokenizer"]))
 
         if self.model is None:
-            self.model = Transformer(self.config)
+            with disable_random_init(enable=disable_init):
+                self.model = Transformer(self.config)
 
         if paths["model"].exists():
             state_dict = st.load_file(str(paths["model"]))
@@ -76,6 +107,8 @@ class ModelParameter(BaseModelIO):
         instance.save_components(save_dir)
 
     @classmethod
-    def load(cls, load_dir: Union[str, Path]) -> "ModelParameter":
+    def load(
+        cls, load_dir: Union[str, Path], disable_init: bool = False
+    ) -> "ModelParameter":
         instance = cls()
-        return instance.load_components(load_dir)
+        return instance.load_components(load_dir, disable_init=disable_init)
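Taken together, callers no longer wrap loading in the context manager themselves; the flag is threaded through ModelParameter.load into load_components. A sketch of the call-site change, using PARAMETER_ROOT as in the scripts below:

# Before this commit, callers managed the context themselves:
# with disable_random_init():
#     param = ModelParameter.load(PARAMETER_ROOT)
#     param.to(device="cuda", dtype=torch.bfloat16)

# After: skipping random init is an option of load() itself.
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)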

View File

@@ -1,5 +1,4 @@
 from astrai.inference.core import (
-    disable_random_init,
     GeneratorCore,
     EmbeddingEncoderCore,
     KVCacheManager,
@@ -15,7 +14,6 @@ from astrai.inference.generator import (
 )
 
 __all__ = [
-    "disable_random_init",
     "GeneratorCore",
     "EmbeddingEncoderCore",
     "KVCacheManager",

View File

@@ -1,8 +1,6 @@
 import torch
-import torch.nn as nn
 from torch import Tensor
-from contextlib import contextmanager
 from typing import Any, Callable, List, Tuple, Union, Optional, Self
 
 from astrai.config import ModelParameter, ModelConfig
@@ -55,31 +53,6 @@ def apply_sampling_strategies(
     return logits
 
-
-@contextmanager
-def disable_random_init():
-    init_functions = [
-        "xavier_normal_",
-        "xavier_uniform_",
-        "kaiming_normal_",
-        "kaiming_uniform_",
-        "zeros_",
-        "ones_",
-        "constant_",
-        "normal_",
-        "uniform_",
-    ]
-    original_funcs = {}
-    for name in init_functions:
-        if hasattr(nn.init, name):
-            original_funcs[name] = getattr(nn.init, name)
-            setattr(nn.init, name, lambda *args, **kwargs: None)
-    try:
-        yield
-    finally:
-        for name, orig_func in original_funcs.items():
-            setattr(nn.init, name, orig_func)
 
 class GeneratorCore:
     def __init__(self, parameter: ModelParameter):
         self.model = parameter.model

View File

@@ -6,7 +6,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 if __name__ == "__main__":
     snapshot_download(
-        repo_id="ViperEk/AstrAI",
+        repo_id="ViperEk/KHAOSZ",
         local_dir=PARAMETER_ROOT,
         force_download=True,
     )

View File

@@ -1,7 +1,6 @@
 import torch
 from pathlib import Path
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 from astrai.inference.generator import GeneratorFactory, GenerationRequest
 
 PROJECT_ROOT = Path(__file__).parent.parent
@@ -9,10 +8,8 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 def generate_text():
-    with disable_random_init():
-        param = ModelParameter.load(PARAMETER_ROOT)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     query = input(">> ")

View File

@@ -1,7 +1,6 @@
 import torch
 from pathlib import Path
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 from astrai.inference.generator import GeneratorFactory, GenerationRequest
 
 PROJECT_ROOT = Path(__file__).parent.parent
@@ -9,10 +8,8 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 def batch_generate():
-    with disable_random_init():
-        param = ModelParameter.load(PARAMETER_ROOT)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     inputs = [
         "你好",

View File

@@ -1,7 +1,6 @@
 import torch
 from pathlib import Path
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 from astrai.inference.generator import GeneratorFactory, GenerationRequest
 
 PROJECT_ROOT = Path(__file__).parent.parent
@@ -9,10 +8,8 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 def chat():
-    with disable_random_init():
-        param = ModelParameter.load(PARAMETER_ROOT)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     history = []
     while True:

View File

@@ -4,23 +4,19 @@ import argparse
 
 from astrai.config.param_config import ModelParameter
 from astrai.inference.generator import BatchGenerator, GenerationRequest
-from astrai.inference.core import disable_random_init
 
 
 def processor(
     model_dir: str,
     input_json_file: str,
     output_json_file: str,
-    batch_size: int,
     temperature: float,
     top_k: int,
     top_p: float,
     question_key: str,
     response_key: str,
 ):
-    with disable_random_init():
-        param = ModelParameter.load(model_dir)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(model_dir, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     generator = BatchGenerator(param)

View File

@@ -7,7 +7,6 @@ import tqdm
 from torch import Tensor
 
 from astrai.config.param_config import ModelParameter
-from astrai.inference.core import disable_random_init
 
 
 def compute_perplexity(
@@ -42,9 +41,7 @@ def compute_perplexity(
 def process_file(
     model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
 ):
-    with disable_random_init():
-        param = ModelParameter.load(model_dir)
-        param.to(device="cuda", dtype=torch.bfloat16)
+    param = ModelParameter.load(model_dir, disable_init=True)
+    param.to(device="cuda", dtype=torch.bfloat16)
 
     model = param.model
     tokenizer = param.tokenizer
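The body of compute_perplexity is unchanged by this commit and not shown in the diff. For reference, perplexity is conventionally the exponential of the mean per-token negative log-likelihood; a hypothetical sketch of that computation (not the repository's actual implementation, and perplexity_from_logits is an illustrative name):

import torch
import torch.nn.functional as F
from torch import Tensor

def perplexity_from_logits(logits: Tensor, targets: Tensor) -> float:
    # logits: (batch, seq, vocab); targets: (batch, seq) token ids.
    # cross_entropy averages the negative log-likelihood over all tokens.
    nll = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
    )
    return torch.exp(nll).item()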