feat: implement dynamic model registration mechanism

ViperEkura 2026-04-05 19:38:12 +08:00
parent ff43a2fab8
commit fc278d17ab
25 changed files with 686 additions and 651 deletions

1
.gitattributes vendored
View File

@@ -3,6 +3,7 @@
# Files that MUST use LF (Unix/Linux execution)
*.sh text eol=lf
*.py text eol=lf
Dockerfile text eol=lf
.dockerignore text eol=lf

View File

@@ -90,8 +90,12 @@ from astrai.inference import InferenceEngine, GenerationRequest
param = ModelParameter.load("your_model_dir")
param.to(device="cuda", dtype=torch.bfloat16)
# Create engine
engine = InferenceEngine(param)
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=param.model,
tokenizer=param.tokenizer,
)
# Build request
request = GenerationRequest(

View File

@@ -12,7 +12,7 @@ from astrai.inference import (
GenerationRequest,
InferenceEngine,
)
from astrai.model.transformer import Transformer
from astrai.model import AutoModel, Transformer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
__all__ = [
@@ -27,4 +27,5 @@ __all__ = [
"StrategyFactory",
"SchedulerFactory",
"BaseFactory",
"AutoModel",
]

View File

@@ -1,11 +1,7 @@
from astrai.config.model_config import ModelConfig
from astrai.config.param_config import BaseModelIO, ModelParameter
from astrai.config.train_config import TrainConfig
__all__ = [
# Base I/O
"BaseModelIO",
"ModelParameter",
# Model configuration
"ModelConfig",
"TrainConfig",

View File

@@ -6,6 +6,7 @@ from typing import Optional, Self
@dataclass
class ModelConfig:
# basic config
model_type: Optional[str] = None
vocab_size: Optional[int] = None
dim: Optional[int] = None

View File

@@ -1,114 +0,0 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config.model_config import ModelConfig
from astrai.tokenize import BpeTokenizer
from astrai.model.transformer import Transformer
@contextmanager
def disable_random_init(enable: bool = True):
init_functions = [
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
"kaiming_uniform_",
"zeros_",
"ones_",
"constant_",
"normal_",
"uniform_",
]
original_funcs = {}
for name in init_functions:
if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
if enable:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
@dataclass
class BaseModelIO:
"""Base class for model I/O operations."""
model: nn.Module = field(
default_factory=nn.Identity, metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
)
config: ModelConfig = field(
default_factory=ModelConfig,
metadata={"help": "Transformer model configuration."},
)
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get standardized file paths for model components."""
dir_path = Path(directory)
return {
"model": dir_path / "model.safetensors",
"config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json",
}
def save_components(self, save_dir: Union[str, Path]):
"""Save core model components."""
paths = self._get_file_paths(save_dir)
paths["model"].parent.mkdir(parents=True, exist_ok=True)
if self.model is not None:
st.save_file(self.model.state_dict(), str(paths["model"]))
self.config.save(str(paths["config"]))
self.tokenizer.save(str(paths["tokenizer"]))
def load_components(
self, load_dir: Union[str, Path], disable_init: bool = False
) -> Self:
"""Load core model components."""
paths = self._get_file_paths(load_dir)
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if isinstance(self.model, nn.Identity):
with disable_random_init(enable=disable_init):
self.model = Transformer(self.config)
if paths["model"].exists():
state_dict = st.load_file(str(paths["model"]))
self.model.load_state_dict(state_dict)
return self
def to(self, *args, **kwargs) -> "BaseModelIO":
"""Move model to device."""
if self.model is not None:
self.model.to(*args, **kwargs)
return self
@dataclass
class ModelParameter(BaseModelIO):
"""Container for model parameters with serialization capabilities."""
@classmethod
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
instance.save_components(save_dir)
@classmethod
def load(
cls, load_dir: Union[str, Path], disable_init: bool = False
) -> "ModelParameter":
instance = cls()
return instance.load_components(load_dir, disable_init=disable_init)

View File

@@ -185,3 +185,6 @@ class BaseFactory(ABC, Generic[T]):
def list_by_priority(cls, reverse: bool = False) -> List[str]:
"""List registered component names sorted by priority."""
return cls._registry.list_by_priority(reverse)
__all__ = ["Registry", "BaseFactory"]

View File

@@ -1,11 +1,11 @@
"""Unified inference engine."""
import threading
import torch
import torch.nn as nn
from typing import Any, Dict, Generator, List, Optional, Union
from astrai.config import ModelParameter
from astrai.tokenize.chat_template import build_prompt
from astrai.tokenize.tokenizer import TextTokenizer
from astrai.inference.scheduler import InferenceScheduler
@@ -14,22 +14,18 @@ class GenerationRequest:
def __init__(
self,
query: Union[str, List[str]],
messages: List[Dict[str, str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_len: int = 1024,
history: Optional[Any] = None,
system_prompt: Optional[str] = None,
stream: bool = False,
):
self.query = query
self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.history = history
self.system_prompt = system_prompt
self.stream = stream
self._validate()
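
For reference, a minimal sketch of the new messages-based request (field values are illustrative):

    request = GenerationRequest(
        messages=[{"role": "user", "content": "Hello!"}],
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        max_len=1024,
        stream=False,
    )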
@@ -107,26 +103,41 @@ class InferenceEngine:
def __init__(
self,
parameter: ModelParameter,
max_batch_size: int = 16,
model: nn.Module,
tokenizer: TextTokenizer,
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
"""
Initialize inference engine with separate model and tokenizer.
model_params = next(self.model.parameters())
self.device = model_params.device
self.dtype = model_params.dtype
Args:
model: The language model for inference (nn.Module, e.g., Transformer)
tokenizer: The tokenizer for encoding/decoding text
max_batch_size: Maximum batch size for continuous batching
max_seq_len: Maximum sequence length (defaults to config.max_len)
"""
self.model = model
self.tokenizer = tokenizer
# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
config=self.config,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
device=self.device,
dtype=self.dtype,
device=device,
dtype=dtype,
)
self.kv_cache = self.scheduler.kv_cache
@@ -160,7 +171,8 @@ class InferenceEngine:
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object."""
prompt = build_prompt(request.query, request.history)
# Use tokenizer's chat template with messages
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,

View File

@@ -8,7 +8,8 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from torch import Tensor
from astrai.config import ModelConfig
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import TextTokenizer
class TaskStatus:
@@ -98,23 +99,23 @@ class InferenceScheduler:
def __init__(
self,
model,
tokenizer,
config: ModelConfig,
model: AutoModel,
tokenizer: TextTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
config = model.config
self.model = model
self.tokenizer = tokenizer
self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.device = device
self.dtype = dtype
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
num_heads = config.n_kv_heads
num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
@@ -123,26 +124,26 @@ class InferenceScheduler:
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
num_kv_heads,
head_dim,
),
device=device,
dtype=dtype,
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_heads,
num_kv_heads,
head_dim,
),
device=device,
dtype=dtype,
device=self.device,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
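
As a sanity check on these cache shapes, a rough footprint estimate (all sizes below are illustrative, not taken from this commit):

    # Illustrative KV-cache sizing for the (batch, seq, layers, kv_heads, head_dim)
    # layout above; every number here is an example value.
    max_batch_size, max_seq_len = 16, 2048
    n_layers, num_kv_heads, head_dim = 24, 8, 64
    bytes_per_elem = 2  # bfloat16
    cache_elems = max_batch_size * max_seq_len * n_layers * num_kv_heads * head_dim
    total_bytes = 2 * cache_elems * bytes_per_elem  # k_cache + v_cache
    print(f"KV cache: {total_bytes / 2**30:.2f} GiB")  # 1.50 GiB for these values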
@@ -259,7 +260,7 @@ class InferenceScheduler:
)
with torch.inference_mode():
outputs = self.model(
self.model(
input_ids,
input_mask=input_mask,
start_pos=0,

View File

@@ -3,8 +3,6 @@ Inference Server with Continuous Batching Support
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
Author: AstrAI Team
"""
import json
@@ -19,14 +17,15 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.config.param_config import ModelParameter
from astrai.inference.engine import GenerationRequest, InferenceEngine
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import TextTokenizer
logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_model_param: Optional[ModelParameter] = None
_engine: Optional[InferenceEngine] = None
_model_param: Optional[AutoModel] = None
_project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
@@ -95,13 +94,17 @@ def load_model(
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
_model_param = ModelParameter.load(param_path, disable_init=True)
# Load tokenizer separately
tokenizer = TextTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
# Initialize inference engine with continuous batching
# Initialize inference engine with separate model and tokenizer
_engine = InferenceEngine(
parameter=_model_param,
model=_model_param,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
@@ -164,27 +167,43 @@ def convert_messages_to_history(
return system_prompt, history if history else None
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
def convert_messages_to_prompt(
messages: List[ChatMessage], engine: Optional[InferenceEngine] = None
) -> str:
"""Convert messages to prompt string.
Args:
messages: List of ChatMessage objects
engine: InferenceEngine instance for accessing tokenizer
Returns:
str: Formatted prompt string
"""
system_prompt, history = convert_messages_to_history(messages)
# Convert to dict format for chat template
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
# Get the last user message as query
user_messages = [m.content for m in messages if m.role == "user"]
if not user_messages:
raise ValueError("No user message found")
query = user_messages[-1]
# Extract system prompt if present
system_prompt = None
filtered_messages = []
for msg in msg_dicts:
if msg["role"] == "system":
system_prompt = msg["content"]
else:
filtered_messages.append(msg)
# Build prompt using chat template
from astrai.tokenize.chat_template import build_prompt
# Use engine's tokenizer chat template if available
if engine is not None and engine.tokenizer is not None:
return engine.tokenizer.apply_chat_template(
filtered_messages, system_prompt=system_prompt, tokenize=False
)
return build_prompt(query, history)
# Fallback: simple concatenation (deprecated)
prompt_parts = []
for msg in filtered_messages:
prompt_parts.append(
f"<im▁start>{msg['role']}\n{msg['content']}<im▁end>"
)
return "\n".join(prompt_parts) + "\n<im▁start>assistant\n"
@app.get("/health")
@@ -213,8 +232,8 @@ async def chat_completion(request: ChatCompletionRequest):
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt
prompt = convert_messages_to_prompt(request.messages)
# Convert messages to prompt using engine's tokenizer
prompt = convert_messages_to_prompt(request.messages, engine=_engine)
if request.stream:
# Streaming response (use synchronous generator)
@@ -294,15 +313,18 @@ async def generate(
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert history format
hist: Optional[List[Tuple[str, str]]] = None
# Build messages for chat template
messages = []
if history:
hist = [(h[0], h[1]) for h in history]
# Convert history format: List[List[str]] -> List[Dict]
for h in history:
if len(h) >= 2:
messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query})
# Build prompt
from astrai.tokenize.chat_template import build_prompt
prompt = build_prompt(query, hist)
# Use tokenizer's chat template
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream:
# Synchronous streaming

View File

@@ -6,5 +6,17 @@ from astrai.model.module import (
RMSNorm,
)
from astrai.model.transformer import Transformer
from astrai.model.automodel import AutoModel
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
__all__ = [
# Modules
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
# Models
"Transformer",
"AutoModel",
]

134
astrai/model/automodel.py Normal file
View File

@@ -0,0 +1,134 @@
"""
AutoModel base class for model loading and saving.
"""
import torch.nn as nn
import safetensors.torch as st
from pathlib import Path
from contextlib import contextmanager
from typing import Self, Union, Dict, Type
from astrai.config import ModelConfig
@contextmanager
def _disable_random_init(enable: bool = True):
init_functions = [
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
"kaiming_uniform_",
"zeros_",
"ones_",
"constant_",
"normal_",
"uniform_",
]
original_funcs = {}
for name in init_functions:
if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
if enable:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
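
A short sketch of what this context manager is for: the checkpoint overwrites the weights immediately after construction, so random initialization is skipped as wasted work (SomeRegisteredModel is a hypothetical subclass, not part of this commit):

    # Inside from_pretrained(): nn.init.* calls become no-ops while the model
    # is constructed, because load_state_dict() overwrites the weights anyway.
    with _disable_random_init(enable=True):
        model = SomeRegisteredModel(config)  # hypothetical AutoModel subclass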
class AutoModel(nn.Module):
"""
Autoregressive language model base class.
Provides model loading/saving and generation capabilities.
"""
# Model registry - stored as class attribute
_registry: Dict[str, Type["AutoModel"]] = {}
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
@classmethod
def register(cls, model_type: str):
"""
Class method decorator to register model type.
Usage:
@AutoModel.register('transformer')
class Transformer(AutoModel):
...
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry[model_type.lower()] = sub_cls
return sub_cls
return decorator
@classmethod
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if model_type not in cls._registry:
available = list(cls._registry.keys())
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry[model_type]
@classmethod
def from_pretrained(
cls,
path: Union[str, Path],
disable_random_init: bool = True,
) -> nn.Module:
model_path = Path(path)
# Load config
config = ModelConfig()
config_path = model_path / "config.json"
if config_path.exists():
config.load(str(config_path))
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
# If called from base class, use model_type to determine actual model class
if cls is AutoModel:
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)
# Load weights
weights_path = model_path / "model.safetensors"
if weights_path.exists():
state_dict = st.load_file(str(weights_path))
model.load_state_dict(state_dict, strict=False)
return model
def save_pretrained(
self,
save_directory: Union[str, Path],
) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
# Save config
self.config.save(str(save_path / "config.json"))
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
def to(self, *args, **kwargs) -> Self:
"""Move model to device/dtype."""
return super().to(*args, **kwargs)
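
Taken together, a minimal usage sketch of the registry above; "my_transformer" and MyTransformer are illustrative names, not part of this commit:

    from astrai.config import ModelConfig
    from astrai.model import AutoModel

    @AutoModel.register("my_transformer")
    class MyTransformer(AutoModel):
        def __init__(self, config: ModelConfig):
            super().__init__(config)
            # ... build layers from config ...

    # from_pretrained() reads config.json, resolves config.model_type
    # ("my_transformer") in the registry, constructs that class, then
    # loads model.safetensors if present.
    model = AutoModel.from_pretrained("your_model_dir")
    model.save_pretrained("your_output_dir")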

View File

@@ -5,6 +5,7 @@ import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
Embedding,
@@ -66,9 +67,14 @@ def process_attention_mask(
return attention_mask
class Transformer(nn.Module):
@AutoModel.register("transformer")
class Transformer(AutoModel):
"""
Transformer language model.
"""
def __init__(self, config: ModelConfig):
super().__init__()
super().__init__(config)
self.config = config
self.rotary_embeding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
@@ -97,16 +103,27 @@ class Transformer(nn.Module):
if self.config.tie_weight:
self.lm_head.weight = self.embed_tokens.weight
self._init_parameters()
self._init_weights()
def _init_weights(self):
for param in self.parameters():
if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006)
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict)
if self.config.tie_weight:
# same tensor
state_dict[lm_head_key] = state_dict[embed_key]
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
@@ -125,11 +142,6 @@ class Transformer(nn.Module):
return state_dict
def _init_parameters(self):
for param in self.parameters():
if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006)
def forward(
self,
input_ids: Tensor,

View File

@@ -1,22 +1,23 @@
from astrai.tokenize.tokenizer import (
BaseTokenizer,
TextTokenizer,
BpeTokenizer,
BaseTrainer,
BpeTrainer,
)
from astrai.tokenize.trainer import BpeTrainer
from astrai.tokenize.chat_template import (
ChatTemplate,
HistoryType,
MessageType,
build_prompt,
)
# Alias for compatibility
AutoTokenizer = TextTokenizer
__all__ = [
"BaseTokenizer",
"TextTokenizer",
"AutoTokenizer",
"BpeTokenizer",
"BaseTrainer",
"BpeTrainer",
"ChatTemplate",
"HistoryType",
"MessageType",
"CHAT_TEMPLATES",
"build_prompt",
]

View File

@@ -1,7 +1,6 @@
from typing import Dict, List, Optional, Tuple, Any
from jinja2 import Template
from dataclasses import dataclass
from astrai.factory import Registry
HistoryType = List[Tuple[str, str]]
MessageType = Dict[str, str]
@@ -75,213 +74,3 @@ class ChatTemplate:
jinja_template = Template(self.template_str)
return jinja_template.render(**variables)
# Global registry instance
_default_registry = Registry()
# Default template name
_default_template_name = "chatml"
# Convenience functions
def register_chat_template(
name: str,
template_str: str,
description: str = "",
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None,
) -> ChatTemplate:
"""Register a chat template in the global registry."""
template = ChatTemplate(
name=name,
template_str=template_str,
description=description,
default_variables=default_variables,
special_tokens=special_tokens,
)
_default_registry.register(name, template, category=None, priority=0)
return template
def set_default_chat_template(name: str) -> None:
"""Set the default chat template name globally."""
global _default_template_name
if not _default_registry.contains(name):
raise KeyError(
f"Chat template '{name}' not found. Available: {list(_default_registry.list_names())}"
)
_default_template_name = name
def get_default_chat_template_name() -> str:
"""Get the current default chat template name."""
return _default_template_name
def get_chat_template(name: str) -> ChatTemplate:
"""Get a chat template from the global registry."""
return _default_registry.get(name)
def list_chat_templates() -> List[str]:
"""List all registered chat template names."""
return _default_registry.list_names()
def chat_template_exists(name: str) -> bool:
"""Check if a chat template exists."""
return _default_registry.contains(name)
def build_prompt(
query: str,
system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None,
template: Optional[str] = None,
template_name: Optional[str] = None,
**extra_variables: Any,
) -> str:
"""Build prompt using a registered chat template or a custom template string.
This function maintains backward compatibility with the previous API.
Args:
query: The current user query.
system_prompt: Optional system prompt.
history: Optional list of (user_msg, assistant_msg) pairs.
template: If provided, uses this exact Jinja2 template string (overrides template_name).
template_name: Name of a registered template to use (ignored if `template` is given).
If None, uses the globally set default template (see `set_default_chat_template`).
**extra_variables: Additional variables to pass to the template.
Returns:
Rendered prompt string.
Raises:
KeyError: If `template_name` is not registered.
"""
# Convert history to message format
messages: List[MessageType] = []
if history:
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": query})
if template is not None:
# Use the provided template string directly
jinja_template = Template(template)
variables = {"messages": messages, **extra_variables}
if system_prompt is not None:
variables["system_prompt"] = system_prompt
return jinja_template.render(**variables)
else:
# Determine which template name to use
if template_name is None:
template_name = _default_template_name
# Use a registered template
chat_template = get_chat_template(template_name)
return chat_template.render(
messages=messages,
system_prompt=system_prompt,
**extra_variables,
)
# Predefined templates
# ChatML template (original)
register_chat_template(
name="chatml",
template_str=(
"{%- if system_prompt -%}\n"
"{{ bos_token }}system\n"
"{{ system_prompt }}{{ eos_token }}\n"
"{%- endif -%}\n"
"{%- for message in messages -%}\n"
"{{ bos_token }}{{ message['role'] }}\n"
"{{ message['content'] }}{{ eos_token }}\n"
"{%- endfor -%}\n"
"{{ bos_token }}assistant\n"
),
description="ChatML format with configurable special tokens.",
special_tokens={"bos_token": "<im▁start>", "eos_token": "<im▁end>"},
)
# Simplified template without special tokens (plain text)
register_chat_template(
name="plain",
template_str=(
"{%- if system_prompt -%}\n"
"System: {{ system_prompt }}\n"
"{%- endif -%}\n"
"{%- for message in messages -%}\n"
"{{ message['role']|capitalize }}: {{ message['content'] }}\n"
"{%- endfor -%}\n"
"Assistant:"
),
description="Plain text format with role labels.",
)
# Alpaca-style template
register_chat_template(
name="alpaca",
template_str=(
"{%- if system_prompt -%}\n"
"### Instruction:\n"
"{{ system_prompt }}\n"
"{%- endif -%}\n"
"### Input:\n"
"{{ messages[-1]['content'] }}\n"
"### Response:"
),
description="Alpaca instructionresponse format (singleturn).",
default_variables={},
)
# OpenAI chat format (approximation)
register_chat_template(
name="openai",
template_str=(
"{%- if system_prompt -%}\n"
"{{ bos_token }}system\n"
"{{ system_prompt }}{{ eos_token }}\n"
"{%- endif -%}\n"
"{%- for message in messages -%}\n"
"{{ bos_token }}{{ message['role'] }}\n"
"{{ message['content'] }}{{ eos_token }}\n"
"{%- endfor -%}\n"
"{{ bos_token }}assistant\n"
),
description="OpenAIcompatible chat format with configurable special tokens.",
special_tokens={"bos_token": "<im▁start>", "eos_token": "<im▁end>"},
)
# Llama2 style with [INST] tags
register_chat_template(
name="llama2",
template_str=(
"{%- if system_prompt -%}\n"
"<<SYS>>\n"
"{{ system_prompt }}\n"
"<</SYS>>\n"
"{%- endif -%}\n"
"[INST] {{ messages[-1]['content'] }} [/INST]"
),
description="Llama2 style with [INST] tags (singleturn).",
default_variables={},
)
__all__ = [
"ChatTemplate",
"register_chat_template",
"get_chat_template",
"list_chat_templates",
"chat_template_exists",
"build_prompt",
"set_default_chat_template",
"get_default_chat_template_name",
"HistoryType",
"MessageType",
]

View File

@@ -1,97 +1,254 @@
from abc import ABC, abstractmethod
from typing import List, Union
"""
Tokenizer module with implementation and auto-loading support.
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Union
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
from astrai.tokenize.chat_template import ChatTemplate
class BaseTokenizer(ABC):
@abstractmethod
def _init_tokenizer(self):
pass
class TextTokenizer:
"""Base tokenizer class with automatic loading support"""
@abstractmethod
def save(self, path):
pass
TOKENIZER_CLASSES = {} # Registry for auto-loading
@abstractmethod
def load(self, path):
pass
def __init__(
self,
path: Optional[Union[str, Path]] = None,
special_token_map: Optional[Dict[str, str]] = None,
chat_template: Optional[str] = None,
):
# Preserve a tokenizer a subclass may have created before calling super().__init__
self._tokenizer: Optional[Tokenizer] = getattr(self, "_tokenizer", None)
self._chat_template: Optional[ChatTemplate] = None
self._special_token_map: Optional[Dict] = special_token_map or {}
if chat_template:
self.set_chat_template(chat_template)
if path:
self.load(path)
def load(self, path: Union[str, Path]):
"""Load tokenizer from directory."""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
config_file = path / "tokenizer_config.json"
self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
if config_file.exists():
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
if "special_tokens" in config:
self._special_token_map.update(config["special_tokens"])
# Load chat template from config
if "chat_template" in config:
self.set_chat_template(config["chat_template"])
@classmethod
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
"""Load tokenizer from pretrained directory."""
instance = cls(path)
return instance
def save_pretrained(self, save_path: Union[str, Path]):
"""
Save tokenizer to a pretrained directory.
Args:
save_path: Directory to write tokenizer.json into
"""
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
self._tokenizer.save(str(save_path / "tokenizer.json"))
@classmethod
def register_tokenizer(cls, name: str, tokenizer_class: type):
"""
Register a new tokenizer class.
Args:
name: Name to register the tokenizer class under
tokenizer_class: The tokenizer class to register
"""
cls.TOKENIZER_CLASSES[name] = tokenizer_class
@abstractmethod
def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
add_special_tokens: bool = False,
is_pretokenized: bool = False,
add_special_tokens: bool = True,
) -> List:
pass
"""Encode text to tokens or token IDs."""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
if isinstance(tokens, str):
encoded = self._tokenizer.encode(
tokens,
is_pretokenized=is_pretokenized,
add_special_tokens=add_special_tokens,
)
return encoded.ids if out_ids else encoded.tokens
else:
encoded_list = self._tokenizer.encode_batch(
tokens,
is_pretokenized=is_pretokenized,
add_special_tokens=add_special_tokens,
)
return [
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
]
@abstractmethod
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
pass
"""Decode token IDs to text."""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
@abstractmethod
def __len__(self) -> int:
pass
if self._tokenizer is None:
return 0
return self._tokenizer.get_vocab_size()
def __getattr__(self, key: str):
"""
Dynamically intercept special token attribute access.
Supports three forms (given a special_token_map entry like {"bos": ...}):
- tokenizer.bos returns the token string
- tokenizer.bos_id returns the corresponding integer ID
- tokenizer.stop_ids returns the integer IDs of all special tokens
"""
# Handle stop_ids: map each special token string to its integer ID
if key == "stop_ids":
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded, cannot convert tokens to ids.")
return [
self._tokenizer.token_to_id(tok)
for tok in self._special_token_map.values()
]
# Handle _id suffix (e.g., bos_token_id -> bos_token)
if key.endswith("_id"):
base_attr = key[:-3] # Remove "_id"
token_str = self._special_token_map.get(base_attr)
if token_str is None:
return None
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
return self._tokenizer.token_to_id(token_str)
# Handle regular string attributes
if key in self._special_token_map:
return self._special_token_map.get(key)
# Other attributes trigger default AttributeError
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
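
With the default special_token_map used by BpeTokenizer below, this enables attribute access along these lines (sketch; the directory path is hypothetical):

    tok = BpeTokenizer(path="your_tokenizer_dir")
    tok.bos       # -> "<begin▁of▁sentence>" (token string)
    tok.bos_id    # -> integer ID via token_to_id()
    tok.stop_ids  # -> IDs of every token in the special_token_map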
@property
@abstractmethod
def stop_ids(self) -> List[int]:
pass
def vocab_size(self) -> int:
return len(self)
@property
@abstractmethod
def bos_id(self) -> int:
pass
def pad_id(self) -> Optional[int]:
"""Return the pad token ID if available."""
pad_token = self._special_token_map.get("pad")
if pad_token is None or self._tokenizer is None:
return None
return self._tokenizer.token_to_id(pad_token)
@property
@abstractmethod
def eos_id(self) -> int:
pass
def set_chat_template(self, template: Union[str, ChatTemplate]):
"""
Set the chat template for the tokenizer.
@property
@abstractmethod
def pad_id(self) -> int:
pass
Args:
template: Either a Jinja2 template string or a ChatTemplate instance.
Raises:
ValueError: If the template is neither a str nor a ChatTemplate.
"""
if isinstance(template, str):
self._chat_template = ChatTemplate.from_string(template)
elif isinstance(template, ChatTemplate):
self._chat_template = template
else:
raise ValueError("Invalid template type, must be str or ChatTemplate.")
def apply_chat_template(
self,
messages: List[Dict[str, str]],
system_prompt: Optional[str] = None,
tokenize: bool = True,
**kwargs,
) -> Union[str, List[int]]:
"""
Apply the chat template to messages and optionally tokenize the result.
Args:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string.
tokenize: Whether to return token IDs (True) or raw string (False).
**kwargs: Additional variables to pass to the template.
Returns:
Either the rendered string or list of token IDs.
Raises:
RuntimeError: If chat template is not set.
"""
if self._chat_template is None:
raise RuntimeError(
"Chat template not set. Use set_chat_template() to set a template first."
)
# Render the template
rendered = self._chat_template.render(
messages=messages,
system_prompt=system_prompt,
**kwargs,
)
if tokenize:
return self.encode(rendered)
return rendered
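
A minimal rendering sketch, assuming a template was already set via set_chat_template():

    prompt = tokenizer.apply_chat_template(
        messages=[
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello! How can I help?"},
            {"role": "user", "content": "What is a GPU?"},
        ],
        system_prompt="You are a helpful assistant.",
        tokenize=False,  # return the rendered string rather than token IDs
    )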
class BaseTrainer(ABC):
def __init__(self, tokenizer: BaseTokenizer):
self.tokenizer = tokenizer
class BpeTokenizer(TextTokenizer):
"""BPE tokenizer implementation."""
@abstractmethod
def train(self, files, vocab_size, min_freq, **kwargs):
pass
@abstractmethod
def train_from_iterator(self, iterator, vocab_size, min_freq, **kwargs):
pass
class BpeTokenizer(BaseTokenizer):
def __init__(
self,
control_tokens: List[str] = None,
special_tokens: List[str] = None,
path=None,
special_token_map: Dict[str, str] = None,
path: Optional[str] = None,
chat_template: Optional[str] = None,
):
self._control_tokens = control_tokens or [
"<begin▁of▁sentence>",
"<end▁of▁sentence>",
"<▁pad▁>",
]
self._special_tokens = special_tokens or [
"<im▁start>",
"<im▁end>",
]
special_token_map = special_token_map or {
"bos": "<begin▁of▁sentence>",
"eos": "<end▁of▁sentence>",
"pad": "<▁pad▁>",
"im_start": "<im▁start>",
"im_end": "<im▁end>",
}
self._tokenizer = None
self._init_tokenizer()
if path is not None:
self.load(path)
super().__init__(
path, special_token_map=special_token_map, chat_template=chat_template
)
def _init_tokenizer(self):
"""Initialize a new BPE tokenizer with default settings."""
model = BPE()
self._tokenizer = Tokenizer(model)
self._tokenizer.normalizer = normalizers.Sequence(
@@ -105,108 +262,3 @@ class BpeTokenizer(BaseTokenizer):
)
self._tokenizer.decoder = decoders.ByteLevel()
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
def save(self, path):
self._tokenizer.save(path)
def load(self, path):
self._tokenizer = Tokenizer.from_file(path)
def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
add_special_tokens: bool = False,
) -> List:
if isinstance(tokens, str):
encoded = self._tokenizer.encode(
tokens, add_special_tokens=add_special_tokens
)
return encoded.ids if out_ids else encoded.tokens
else:
encoded_list = self._tokenizer.encode_batch(
tokens, add_special_tokens=add_special_tokens
)
return [
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
]
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int:
return self._tokenizer.get_vocab_size()
@property
def stop_ids(self) -> List[int]:
stop_token = self._control_tokens + self._special_tokens
return [self._tokenizer.token_to_id(tok) for tok in stop_token]
@property
def bos_id(self) -> int:
return self._tokenizer.token_to_id(self._control_tokens[0])
@property
def eos_id(self) -> int:
return self._tokenizer.token_to_id(self._control_tokens[1])
@property
def pad_id(self) -> int:
return self._tokenizer.token_to_id(self._control_tokens[2])
class BpeTrainer(BaseTrainer):
def __init__(self, tokenizer: BaseTokenizer):
super().__init__(tokenizer)
def _prepare_trainer(
self,
vocab_size: int,
min_freq: int,
reserved_token_size: int,
max_token_length=18,
):
assert reserved_token_size > len(self.tokenizer._special_tokens)
reserved_tokens = [
f"<|reserve{i:02d}|>"
for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
]
detail_vocab_size = vocab_size - (
len(reserved_tokens) + len(self.tokenizer._special_tokens)
)
alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self.tokenizer._control_tokens)
assert detail_vocab_size > min_size
trainer = BpeTrainerImpl(
vocab_size=detail_vocab_size,
min_frequency=min_freq,
limit_alphabet=detail_vocab_size // 6,
max_token_length=max_token_length,
special_tokens=self.tokenizer._control_tokens,
initial_alphabet=alphabet,
show_progress=True,
)
return trainer, reserved_tokens
def train(self, files, vocab_size, min_freq, reserved_token_size=100, **kwargs):
trainer, reserved_tokens = self._prepare_trainer(
vocab_size, min_freq, reserved_token_size, **kwargs
)
self.tokenizer._tokenizer.train(files=files, trainer=trainer)
self.tokenizer._tokenizer.add_special_tokens(
self.tokenizer._special_tokens + reserved_tokens
)
def train_from_iterator(
self, iterator, vocab_size, min_freq, reserved_token_size=100, **kwargs
):
trainer, reserved_tokens = self._prepare_trainer(
vocab_size, min_freq, reserved_token_size, **kwargs
)
self.tokenizer._tokenizer.train_from_iterator(
iterator=iterator, trainer=trainer
)
self.tokenizer._tokenizer.add_special_tokens(
self.tokenizer._special_tokens + reserved_tokens
)

108
astrai/tokenize/trainer.py Normal file
View File

@@ -0,0 +1,108 @@
"""
BPE Tokenizer Trainer module.
Provides training functionality for BPE tokenizers.
"""
from typing import List, Union
from tokenizers import pre_tokenizers
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
class BpeTrainer:
"""BPE tokenizer trainer."""
def __init__(self, tokenizer):
"""Initialize trainer with a tokenizer instance.
Args:
tokenizer: A BpeTokenizer instance
"""
self.tokenizer = tokenizer
def _prepare_trainer(
self,
vocab_size: int,
min_freq: int,
reserved_token_size: int,
max_token_length: int = 18,
):
"""Prepare the BPE trainer with proper configuration."""
assert reserved_token_size > len(self.tokenizer._special_tokens)
reserved_tokens = [
f"<|reserve{i:02d}|>"
for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
]
detail_vocab_size = vocab_size - (
len(reserved_tokens) + len(self.tokenizer._special_tokens)
)
alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self.tokenizer._control_tokens)
assert detail_vocab_size > min_size
trainer = BpeTrainerImpl(
vocab_size=detail_vocab_size,
min_frequency=min_freq,
limit_alphabet=detail_vocab_size // 6,
max_token_length=max_token_length,
special_tokens=self.tokenizer._control_tokens,
initial_alphabet=alphabet,
show_progress=True,
)
return trainer, reserved_tokens
def train(
self,
files: Union[str, List[str]],
vocab_size: int,
min_freq: int,
reserved_token_size: int = 100,
**kwargs,
):
"""Train tokenizer from files.
Args:
files: Path or list of paths to training files
vocab_size: Target vocabulary size
min_freq: Minimum frequency for tokens
reserved_token_size: Number of reserved tokens
**kwargs: Additional arguments
"""
trainer, reserved_tokens = self._prepare_trainer(
vocab_size, min_freq, reserved_token_size, **kwargs
)
self.tokenizer._tokenizer.train(files=files, trainer=trainer)
self.tokenizer._tokenizer.add_special_tokens(
self.tokenizer._special_tokens + reserved_tokens
)
def train_from_iterator(
self,
iterator,
vocab_size: int,
min_freq: int,
reserved_token_size: int = 100,
**kwargs,
):
"""Train tokenizer from iterator.
Args:
iterator: Iterator yielding training strings
vocab_size: Target vocabulary size
min_freq: Minimum frequency for tokens
reserved_token_size: Number of reserved tokens
**kwargs: Additional arguments
"""
trainer, reserved_tokens = self._prepare_trainer(
vocab_size, min_freq, reserved_token_size, **kwargs
)
self.tokenizer._tokenizer.train_from_iterator(
iterator=iterator, trainer=trainer
)
self.tokenizer._tokenizer.add_special_tokens(
self.tokenizer._special_tokens + reserved_tokens
)
__all__ = ["BpeTrainer"]

View File

@@ -2,24 +2,32 @@ from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
# Load tokenizer from pretrained
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
query = input(">> ")
engine = InferenceEngine(param)
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
response = engine.generate(
prompt=query,
stream=False,
max_tokens=param.config.max_len,
max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,

View File

@@ -2,7 +2,7 @@ from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine
PROJECT_ROOT = Path(__file__).resolve().parents[2]
@@ -10,8 +10,10 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def batch_generate():
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)
# Load model using AutoModel, then move it to the target device/dtype
model = AutoModel.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
inputs = [
"你好",
@@ -21,11 +23,14 @@ def batch_generate():
"请问什么是显卡",
]
engine = InferenceEngine(param)
engine = InferenceEngine(
model=model,
tokenizer=AutoTokenizer.from_pretrained(PARAMETER_ROOT),
)
responses = engine.generate(
prompt=inputs,
stream=False,
max_tokens=param.config.max_len,
max_tokens=model.config.max_len,
temperature=0.8,
top_p=0.95,
top_k=50,

View File

@@ -1,32 +1,39 @@
from pathlib import Path
import torch
from astrai.config.param_config import ModelParameter
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def chat():
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)
model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
history = []
engine = InferenceEngine(param)
messages = []
engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True:
query = input(">> ")
if query == "!exit":
break
# Add user message
messages.append({"role": "user", "content": query})
# Generate response
full_response = ""
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
for token in engine.generate(
prompt=query,
prompt=prompt,
stream=True,
max_tokens=param.config.max_len,
max_tokens=model.config.max_len,
temperature=0.8,
top_p=0.95,
top_k=50,
@@ -35,7 +42,8 @@ def chat():
full_response += token
print()
history.append((query, full_response.strip()))
# Add assistant response to messages
messages.append({"role": "assistant", "content": full_response.strip()})
if __name__ == "__main__":

View File

@@ -3,7 +3,8 @@ import json
import torch
from astrai.config.param_config import ModelParameter
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine
@@ -17,9 +18,9 @@ def processor(
question_key: str,
response_key: str,
):
param = ModelParameter.load(model_dir, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)
engine = InferenceEngine(param)
# Load model and tokenizer, then move the model to the target device/dtype
model = AutoModel.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
engine = InferenceEngine(model=model, tokenizer=tokenizer)
with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
@@ -29,7 +30,7 @@
responses = engine.generate(
prompt=queries,
stream=False,
max_tokens=param.config.max_len,
max_tokens=model.config.max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
import tqdm
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
def compute_perplexity(
@@ -20,7 +20,7 @@ def compute_perplexity(
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
"""
output = model(input_ids, input_mask)
output = model(input_ids, input_mask=input_mask)
logits = output["logits"]
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
@@ -42,10 +42,9 @@ def compute_perplexity(
def process_file(
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
):
param = ModelParameter.load(model_dir, disable_init=True)
param.to(device="cuda", dtype=torch.bfloat16)
model = param.model
tokenizer = param.tokenizer
# Load model via AutoModel, then move it to the target device/dtype
model = AutoModel.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
with open(input_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
@@ -54,7 +53,7 @@
encoded_texts = [tokenizer.encode(text) for text in texts]
output_data = []
for i in tqdm(
for i in tqdm.tqdm(
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
):
batch_encoded = encoded_texts[i : i + batch_size]
@@ -72,7 +71,7 @@
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
perplexity = compute_perplexity(model, input_ids, input_mask)
perplexity = compute_perplexity(model, input_ids, input_mask)
for text, ppl in zip(batch_texts, perplexity):
output_data.append({text_key: text, "ppl": float(ppl.item())})

View File

@@ -5,10 +5,12 @@ from functools import partial
import torch
import torch.nn as nn
import torch.optim as optim
import safetensors.torch as st
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelParameter, TrainConfig
from astrai.config import ModelConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import Transformer
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
@@ -196,12 +198,23 @@ def train(
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
parameter = ModelParameter.load(param_path)
# Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json")
if os.path.exists(config_path):
config.load(config_path)
if window_size is None:
window_size = parameter.config.max_len
window_size = config.max_len
model = parameter.model
# Create bare Transformer (for training, no tokenizer needed)
model = Transformer(config)
# Load weights if available
weights_path = os.path.join(param_path, "model.safetensors")
if os.path.exists(weights_path):
state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False)
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from astrai.inference.server import app, _engine
from astrai.inference.server import app
@pytest.fixture

View File

@@ -1,34 +0,0 @@
import os
import torch
from astrai.config.param_config import ModelParameter
def test_model_parameter(test_env):
save_dir = os.path.join(test_env["test_dir"], "save")
model_param = ModelParameter(
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
)
ModelParameter.save(model_param, save_dir)
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
assert os.path.exists(os.path.join(save_dir, "config.json"))
# transformer
def test_transformer(test_env):
model = test_env["model"]
input_ids = torch.randint(
0,
test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].max_len),
)
output_logits = model(input_ids)["logits"]
target_shape = (
4,
test_env["transformer_config"].max_len,
test_env["transformer_config"].vocab_size,
)
assert output_logits.shape == target_shape