feat: 实现模型动态注册机制
This commit is contained in:
parent
ff43a2fab8
commit
fc278d17ab
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
# Files that MUST use LF (Unix/Linux execution)
|
# Files that MUST use LF (Unix/Linux execution)
|
||||||
*.sh text eol=lf
|
*.sh text eol=lf
|
||||||
|
*.py text eol=lf
|
||||||
Dockerfile text eol=lf
|
Dockerfile text eol=lf
|
||||||
.dockerignore text eol=lf
|
.dockerignore text eol=lf
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -90,8 +90,12 @@ from astrai.inference import InferenceEngine, GenerationRequest
|
||||||
param = ModelParameter.load("your_model_dir")
|
param = ModelParameter.load("your_model_dir")
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
param.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Create engine
|
# Create engine with separate model and tokenizer
|
||||||
engine = InferenceEngine(param)
|
engine = InferenceEngine(
|
||||||
|
model=param.model,
|
||||||
|
tokenizer=param.tokenizer,
|
||||||
|
config=param.config,
|
||||||
|
)
|
||||||
|
|
||||||
# Build request
|
# Build request
|
||||||
request = GenerationRequest(
|
request = GenerationRequest(
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from astrai.inference import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model import AutoModel, Transformer
|
||||||
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
|
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -27,4 +27,5 @@ __all__ = [
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
"BaseFactory",
|
"BaseFactory",
|
||||||
|
"AutoModel",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,7 @@
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.config.param_config import BaseModelIO, ModelParameter
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base I/O
|
|
||||||
"BaseModelIO",
|
|
||||||
"ModelParameter",
|
|
||||||
# Model configuration
|
# Model configuration
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import Optional, Self
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
# basic config
|
# basic config
|
||||||
|
model_type: Optional[str] = None
|
||||||
vocab_size: Optional[int] = None
|
vocab_size: Optional[int] = None
|
||||||
dim: Optional[int] = None
|
dim: Optional[int] = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,114 +0,0 @@
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Self, Union
|
|
||||||
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
|
||||||
from astrai.tokenize import BpeTokenizer
|
|
||||||
from astrai.model.transformer import Transformer
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def disable_random_init(enable: bool = True):
|
|
||||||
init_functions = [
|
|
||||||
"xavier_normal_",
|
|
||||||
"xavier_uniform_",
|
|
||||||
"kaiming_normal_",
|
|
||||||
"kaiming_uniform_",
|
|
||||||
"zeros_",
|
|
||||||
"ones_",
|
|
||||||
"constant_",
|
|
||||||
"normal_",
|
|
||||||
"uniform_",
|
|
||||||
]
|
|
||||||
original_funcs = {}
|
|
||||||
for name in init_functions:
|
|
||||||
if enable and hasattr(nn.init, name):
|
|
||||||
original_funcs[name] = getattr(nn.init, name)
|
|
||||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if enable:
|
|
||||||
for name, orig_func in original_funcs.items():
|
|
||||||
setattr(nn.init, name, orig_func)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseModelIO:
|
|
||||||
"""Base class for model I/O operations."""
|
|
||||||
|
|
||||||
model: nn.Module = field(
|
|
||||||
default_factory=nn.Identity, metadata={"help": "Transformer model."}
|
|
||||||
)
|
|
||||||
tokenizer: BpeTokenizer = field(
|
|
||||||
default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
|
|
||||||
)
|
|
||||||
config: ModelConfig = field(
|
|
||||||
default_factory=ModelConfig,
|
|
||||||
metadata={"help": "Transformer model configuration."},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
|
||||||
"""Get standardized file paths for model components."""
|
|
||||||
dir_path = Path(directory)
|
|
||||||
return {
|
|
||||||
"model": dir_path / "model.safetensors",
|
|
||||||
"config": dir_path / "config.json",
|
|
||||||
"tokenizer": dir_path / "tokenizer.json",
|
|
||||||
}
|
|
||||||
|
|
||||||
def save_components(self, save_dir: Union[str, Path]):
|
|
||||||
"""Save core model components."""
|
|
||||||
paths = self._get_file_paths(save_dir)
|
|
||||||
paths["model"].parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
if self.model is not None:
|
|
||||||
st.save_file(self.model.state_dict(), str(paths["model"]))
|
|
||||||
|
|
||||||
self.config.save(str(paths["config"]))
|
|
||||||
self.tokenizer.save(str(paths["tokenizer"]))
|
|
||||||
|
|
||||||
def load_components(
|
|
||||||
self, load_dir: Union[str, Path], disable_init: bool = False
|
|
||||||
) -> Self:
|
|
||||||
"""Load core model components."""
|
|
||||||
paths = self._get_file_paths(load_dir)
|
|
||||||
|
|
||||||
self.config.load(str(paths["config"]))
|
|
||||||
self.tokenizer.load(str(paths["tokenizer"]))
|
|
||||||
|
|
||||||
if isinstance(self.model, nn.Identity):
|
|
||||||
with disable_random_init(enable=disable_init):
|
|
||||||
self.model = Transformer(self.config)
|
|
||||||
|
|
||||||
if paths["model"].exists():
|
|
||||||
state_dict = st.load_file(str(paths["model"]))
|
|
||||||
self.model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs) -> "BaseModelIO":
|
|
||||||
"""Move model to device."""
|
|
||||||
if self.model is not None:
|
|
||||||
self.model.to(*args, **kwargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelParameter(BaseModelIO):
|
|
||||||
"""Container for model parameters with serialization capabilities."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
|
|
||||||
instance.save_components(save_dir)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls, load_dir: Union[str, Path], disable_init: bool = False
|
|
||||||
) -> "ModelParameter":
|
|
||||||
instance = cls()
|
|
||||||
return instance.load_components(load_dir, disable_init=disable_init)
|
|
||||||
|
|
@ -185,3 +185,6 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
def list_by_priority(cls, reverse: bool = False) -> List[str]:
|
def list_by_priority(cls, reverse: bool = False) -> List[str]:
|
||||||
"""List registered component names sorted by priority."""
|
"""List registered component names sorted by priority."""
|
||||||
return cls._registry.list_by_priority(reverse)
|
return cls._registry.list_by_priority(reverse)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Registry", "BaseFactory"]
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
"""Unified inference engine."""
|
"""Unified inference engine."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from typing import Any, Dict, Generator, List, Optional, Union
|
from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
from astrai.config import ModelParameter
|
from astrai.tokenize.tokenizer import TextTokenizer
|
||||||
from astrai.tokenize.chat_template import build_prompt
|
|
||||||
|
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,22 +14,18 @@ class GenerationRequest:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
query: Union[str, List[str]],
|
messages: List[Dict[str, str]],
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
max_len: int = 1024,
|
max_len: int = 1024,
|
||||||
history: Optional[Any] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
self.query = query
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self.history = history
|
|
||||||
self.system_prompt = system_prompt
|
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
|
||||||
self._validate()
|
self._validate()
|
||||||
|
|
@ -107,26 +103,41 @@ class InferenceEngine:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parameter: ModelParameter,
|
model: nn.Module,
|
||||||
max_batch_size: int = 16,
|
tokenizer: TextTokenizer,
|
||||||
|
max_batch_size: int = 1,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.model = parameter.model
|
"""
|
||||||
self.tokenizer = parameter.tokenizer
|
Initialize inference engine with separate model and tokenizer.
|
||||||
self.config = parameter.config
|
|
||||||
|
|
||||||
model_params = next(self.model.parameters())
|
Args:
|
||||||
self.device = model_params.device
|
model: The language model for inference (nn.Module, e.g., Transformer)
|
||||||
self.dtype = model_params.dtype
|
tokenizer: The tokenizer for encoding/decoding text
|
||||||
|
config: Model configuration
|
||||||
|
max_batch_size: Maximum batch size for continuous batching
|
||||||
|
max_seq_len: Maximum sequence length (defaults to config.max_len)
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
# Get device and dtype from model parameters
|
||||||
|
try:
|
||||||
|
first_param = next(model.parameters())
|
||||||
|
device = first_param.device
|
||||||
|
dtype = first_param.dtype
|
||||||
|
except StopIteration:
|
||||||
|
# Model has no parameters, use default device/dtype
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
self.scheduler = InferenceScheduler(
|
self.scheduler = InferenceScheduler(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
config=self.config,
|
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
device=self.device,
|
device=device,
|
||||||
dtype=self.dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_cache = self.scheduler.kv_cache
|
self.kv_cache = self.scheduler.kv_cache
|
||||||
|
|
@ -160,7 +171,8 @@ class InferenceEngine:
|
||||||
self, request: GenerationRequest
|
self, request: GenerationRequest
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
"""Generate with GenerationRequest object."""
|
"""Generate with GenerationRequest object."""
|
||||||
prompt = build_prompt(request.query, request.history)
|
# Use tokenizer's chat template with messages
|
||||||
|
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||||
|
|
||||||
return self.generate(
|
return self.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,8 @@ from typing import Any, Callable, Dict, List, Optional
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.tokenize.tokenizer import TextTokenizer
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus:
|
class TaskStatus:
|
||||||
|
|
@ -98,23 +99,23 @@ class InferenceScheduler:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model: AutoModel,
|
||||||
tokenizer,
|
tokenizer: TextTokenizer,
|
||||||
config: ModelConfig,
|
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
):
|
):
|
||||||
|
config = model.config
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.config = config
|
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
self.max_seq_len = max_seq_len or config.max_len
|
||||||
self.device = device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
num_heads = config.n_kv_heads
|
num_kv_heads = config.n_kv_heads
|
||||||
head_dim = config.dim // config.n_heads
|
head_dim = config.dim // config.n_heads
|
||||||
n_layers = config.n_layers
|
n_layers = config.n_layers
|
||||||
|
|
||||||
|
|
@ -123,26 +124,26 @@ class InferenceScheduler:
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
self.max_seq_len,
|
self.max_seq_len,
|
||||||
n_layers,
|
n_layers,
|
||||||
num_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
),
|
),
|
||||||
device=device,
|
device=self.device,
|
||||||
dtype=dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
v_cache = torch.empty(
|
v_cache = torch.empty(
|
||||||
(
|
(
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
self.max_seq_len,
|
self.max_seq_len,
|
||||||
n_layers,
|
n_layers,
|
||||||
num_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
),
|
),
|
||||||
device=device,
|
device=self.device,
|
||||||
dtype=dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.kv_cache = (k_cache, v_cache)
|
self.kv_cache = (k_cache, v_cache)
|
||||||
self.seq_mask = torch.ones(
|
self.seq_mask = torch.ones(
|
||||||
(max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
|
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
|
||||||
)
|
)
|
||||||
|
|
||||||
self.waiting_queue: List[Task] = []
|
self.waiting_queue: List[Task] = []
|
||||||
|
|
@ -259,7 +260,7 @@ class InferenceScheduler:
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(
|
self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
start_pos=0,
|
start_pos=0,
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,6 @@ Inference Server with Continuous Batching Support
|
||||||
|
|
||||||
FastAPI server for inference with continuous batching.
|
FastAPI server for inference with continuous batching.
|
||||||
Provides OpenAI-compatible chat completion endpoints.
|
Provides OpenAI-compatible chat completion endpoints.
|
||||||
|
|
||||||
Author: AstrAI Team
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
@ -19,14 +17,15 @@ from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.inference.engine import InferenceEngine
|
||||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import TextTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Global model parameter and engine (loaded once)
|
# Global model parameter and engine (loaded once)
|
||||||
_model_param: Optional[ModelParameter] = None
|
|
||||||
_engine: Optional[InferenceEngine] = None
|
_engine: Optional[InferenceEngine] = None
|
||||||
|
_model_param: Optional[Any] = None
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
# Server configuration (set before running server)
|
# Server configuration (set before running server)
|
||||||
|
|
@ -95,13 +94,17 @@ def load_model(
|
||||||
param_path = _project_root / "params"
|
param_path = _project_root / "params"
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
_model_param = ModelParameter.load(param_path, disable_init=True)
|
|
||||||
|
# Load tokenizer separately
|
||||||
|
tokenizer = TextTokenizer.from_pretrained(param_path)
|
||||||
|
_model_param = AutoModel.from_pretrained(param_path, tokenizer=tokenizer)
|
||||||
_model_param.to(device=device, dtype=dtype)
|
_model_param.to(device=device, dtype=dtype)
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
# Initialize inference engine with continuous batching
|
# Initialize inference engine with separate model and tokenizer
|
||||||
_engine = InferenceEngine(
|
_engine = InferenceEngine(
|
||||||
parameter=_model_param,
|
model=_model_param,
|
||||||
|
tokenizer=tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
)
|
)
|
||||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
|
@ -164,27 +167,43 @@ def convert_messages_to_history(
|
||||||
return system_prompt, history if history else None
|
return system_prompt, history if history else None
|
||||||
|
|
||||||
|
|
||||||
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
|
def convert_messages_to_prompt(
|
||||||
|
messages: List[ChatMessage], engine: InferenceEngine = None
|
||||||
|
) -> str:
|
||||||
"""Convert messages to prompt string.
|
"""Convert messages to prompt string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of ChatMessage objects
|
messages: List of ChatMessage objects
|
||||||
|
engine: InferenceEngine instance for accessing tokenizer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Formatted prompt string
|
str: Formatted prompt string
|
||||||
"""
|
"""
|
||||||
system_prompt, history = convert_messages_to_history(messages)
|
# Convert to dict format for chat template
|
||||||
|
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
|
||||||
|
|
||||||
# Get the last user message as query
|
# Extract system prompt if present
|
||||||
user_messages = [m.content for m in messages if m.role == "user"]
|
system_prompt = None
|
||||||
if not user_messages:
|
filtered_messages = []
|
||||||
raise ValueError("No user message found")
|
for msg in msg_dicts:
|
||||||
query = user_messages[-1]
|
if msg["role"] == "system":
|
||||||
|
system_prompt = msg["content"]
|
||||||
|
else:
|
||||||
|
filtered_messages.append(msg)
|
||||||
|
|
||||||
# Build prompt using chat template
|
# Use engine's tokenizer chat template if available
|
||||||
from astrai.tokenize.chat_template import build_prompt
|
if engine is not None and engine.tokenizer is not None:
|
||||||
|
return engine.tokenizer.apply_chat_template(
|
||||||
|
filtered_messages, system_prompt=system_prompt, tokenize=False
|
||||||
|
)
|
||||||
|
|
||||||
return build_prompt(query, history)
|
# Fallback: simple concatenation (deprecated)
|
||||||
|
prompt_parts = []
|
||||||
|
for msg in filtered_messages:
|
||||||
|
prompt_parts.append(
|
||||||
|
f"<|im▁start|>{msg['role']}\n{msg['content']}<|im▁end|>"
|
||||||
|
)
|
||||||
|
return "\n".join(prompt_parts) + "\n<|im▁start|>assistant\n"
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|
@ -213,8 +232,8 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
|
||||||
# Convert messages to prompt
|
# Convert messages to prompt using engine's tokenizer
|
||||||
prompt = convert_messages_to_prompt(request.messages)
|
prompt = convert_messages_to_prompt(request.messages, engine=_engine)
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
# Streaming response (use synchronous generator)
|
# Streaming response (use synchronous generator)
|
||||||
|
|
@ -294,15 +313,18 @@ async def generate(
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
|
||||||
# Convert history format
|
# Build messages for chat template
|
||||||
hist: Optional[List[Tuple[str, str]]] = None
|
messages = []
|
||||||
if history:
|
if history:
|
||||||
hist = [(h[0], h[1]) for h in history]
|
# Convert history format: List[List[str]] -> List[Dict]
|
||||||
|
for h in history:
|
||||||
|
if len(h) >= 2:
|
||||||
|
messages.append({"role": "user", "content": h[0]})
|
||||||
|
messages.append({"role": "assistant", "content": h[1]})
|
||||||
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
# Build prompt
|
# Use tokenizer's chat template
|
||||||
from astrai.tokenize.chat_template import build_prompt
|
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
prompt = build_prompt(query, hist)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
# Synchronous streaming
|
# Synchronous streaming
|
||||||
|
|
|
||||||
|
|
@ -6,5 +6,17 @@ from astrai.model.module import (
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
)
|
)
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
|
||||||
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
|
|
||||||
|
__all__ = [
|
||||||
|
# Modules
|
||||||
|
"Linear",
|
||||||
|
"RMSNorm",
|
||||||
|
"MLP",
|
||||||
|
"GQA",
|
||||||
|
"DecoderBlock",
|
||||||
|
# Models
|
||||||
|
"Transformer",
|
||||||
|
"AutoModel",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,134 @@
|
||||||
|
"""
|
||||||
|
AutoModel base class for model loading and saving.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
import safetensors.torch as st
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Self, Union, Dict, Type
|
||||||
|
|
||||||
|
from astrai.config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _disable_random_init(enable: bool = True):
|
||||||
|
init_functions = [
|
||||||
|
"xavier_normal_",
|
||||||
|
"xavier_uniform_",
|
||||||
|
"kaiming_normal_",
|
||||||
|
"kaiming_uniform_",
|
||||||
|
"zeros_",
|
||||||
|
"ones_",
|
||||||
|
"constant_",
|
||||||
|
"normal_",
|
||||||
|
"uniform_",
|
||||||
|
]
|
||||||
|
original_funcs = {}
|
||||||
|
for name in init_functions:
|
||||||
|
if enable and hasattr(nn.init, name):
|
||||||
|
original_funcs[name] = getattr(nn.init, name)
|
||||||
|
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if enable:
|
||||||
|
for name, orig_func in original_funcs.items():
|
||||||
|
setattr(nn.init, name, orig_func)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModel(nn.Module):
|
||||||
|
"""
|
||||||
|
Autoregressive language model base class.
|
||||||
|
Provides model loading/saving and generation capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model registry - stored as class attribute
|
||||||
|
_registry: Dict[str, Type["AutoModel"]] = {}
|
||||||
|
|
||||||
|
def __init__(self, config: ModelConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, model_type: str):
|
||||||
|
"""
|
||||||
|
Class method decorator to register model type.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@AutoModel.register('transformer')
|
||||||
|
class Transformer(AutoModel):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
||||||
|
cls._registry[model_type.lower()] = sub_cls
|
||||||
|
return sub_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
||||||
|
"""Get model class by model_type string."""
|
||||||
|
model_type = model_type.lower()
|
||||||
|
if model_type not in cls._registry:
|
||||||
|
available = list(cls._registry.keys())
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown model_type: {model_type}. Available: {available}"
|
||||||
|
)
|
||||||
|
return cls._registry[model_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
path: Union[str, Path],
|
||||||
|
disable_random_init: bool = True,
|
||||||
|
) -> nn.Module:
|
||||||
|
|
||||||
|
model_path = Path(path)
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
config = ModelConfig()
|
||||||
|
config_path = model_path / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
config.load(str(config_path))
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
|
# If called from base class, use model_type to determine actual model class
|
||||||
|
if cls is AutoModel:
|
||||||
|
model_type = config.model_type or "transformer"
|
||||||
|
actual_cls = cls.get_model_class(model_type)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot call from_pretrained() on subclass {cls.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
with _disable_random_init(enable=disable_random_init):
|
||||||
|
model = actual_cls(config)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
weights_path = model_path / "model.safetensors"
|
||||||
|
if weights_path.exists():
|
||||||
|
state_dict = st.load_file(str(weights_path))
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, Path],
|
||||||
|
) -> None:
|
||||||
|
save_path = Path(save_directory)
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save config
|
||||||
|
self.config.save(str(save_path / "config.json"))
|
||||||
|
|
||||||
|
# Save weights
|
||||||
|
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs) -> Self:
|
||||||
|
"""Move model to device/dtype."""
|
||||||
|
return super().to(*args, **kwargs)
|
||||||
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.module import (
|
||||||
DecoderBlock,
|
DecoderBlock,
|
||||||
Embedding,
|
Embedding,
|
||||||
|
|
@ -66,9 +67,14 @@ def process_attention_mask(
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
@AutoModel.register("transformer")
|
||||||
|
class Transformer(AutoModel):
|
||||||
|
"""
|
||||||
|
Transformer language model.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rotary_embeding = RotaryEmbedding(
|
self.rotary_embeding = RotaryEmbedding(
|
||||||
config.dim // config.n_heads, config.max_len
|
config.dim // config.n_heads, config.max_len
|
||||||
|
|
@ -97,16 +103,27 @@ class Transformer(nn.Module):
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
self._init_parameters()
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
for param in self.parameters():
|
||||||
|
if param.dim() > 1:
|
||||||
|
nn.init.normal_(param, mean=0.0, std=0.006)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
lm_head_key = "lm_head.weight"
|
lm_head_key = "lm_head.weight"
|
||||||
embed_key = "embed_tokens.weight"
|
embed_key = "embed_tokens.weight"
|
||||||
|
|
||||||
|
# Make a copy to avoid modifying the original state_dict
|
||||||
|
state_dict = dict(state_dict)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight:
|
||||||
# same tensor
|
# same tensor
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
if embed_key in state_dict:
|
||||||
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
else:
|
else:
|
||||||
|
# If lm_head.weight exists in checkpoint, use it directly
|
||||||
|
# If not, copy from embed_tokens.weight
|
||||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||||
# use clone to avoid sharing the same tensor
|
# use clone to avoid sharing the same tensor
|
||||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||||
|
|
@ -125,11 +142,6 @@ class Transformer(nn.Module):
|
||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def _init_parameters(self):
|
|
||||||
for param in self.parameters():
|
|
||||||
if param.dim() > 1:
|
|
||||||
nn.init.normal_(param, mean=0.0, std=0.006)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,23 @@
|
||||||
from astrai.tokenize.tokenizer import (
|
from astrai.tokenize.tokenizer import (
|
||||||
BaseTokenizer,
|
TextTokenizer,
|
||||||
BpeTokenizer,
|
BpeTokenizer,
|
||||||
BaseTrainer,
|
|
||||||
BpeTrainer,
|
|
||||||
)
|
)
|
||||||
|
from astrai.tokenize.trainer import BpeTrainer
|
||||||
from astrai.tokenize.chat_template import (
|
from astrai.tokenize.chat_template import (
|
||||||
|
ChatTemplate,
|
||||||
HistoryType,
|
HistoryType,
|
||||||
MessageType,
|
MessageType,
|
||||||
build_prompt,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Alias for compatibility
|
||||||
|
AutoTokenizer = TextTokenizer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseTokenizer",
|
"TextTokenizer",
|
||||||
|
"AutoTokenizer",
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
"BaseTrainer",
|
|
||||||
"BpeTrainer",
|
"BpeTrainer",
|
||||||
|
"ChatTemplate",
|
||||||
"HistoryType",
|
"HistoryType",
|
||||||
"MessageType",
|
"MessageType",
|
||||||
"CHAT_TEMPLATES",
|
|
||||||
"build_prompt",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from typing import Dict, List, Optional, Tuple, Any
|
from typing import Dict, List, Optional, Tuple, Any
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from astrai.factory import Registry
|
|
||||||
|
|
||||||
HistoryType = List[Tuple[str, str]]
|
HistoryType = List[Tuple[str, str]]
|
||||||
MessageType = Dict[str, str]
|
MessageType = Dict[str, str]
|
||||||
|
|
@ -75,213 +74,3 @@ class ChatTemplate:
|
||||||
|
|
||||||
jinja_template = Template(self.template_str)
|
jinja_template = Template(self.template_str)
|
||||||
return jinja_template.render(**variables)
|
return jinja_template.render(**variables)
|
||||||
|
|
||||||
|
|
||||||
# Global registry instance
|
|
||||||
_default_registry = Registry()
|
|
||||||
|
|
||||||
# Default template name
|
|
||||||
_default_template_name = "chatml"
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience functions
|
|
||||||
def register_chat_template(
|
|
||||||
name: str,
|
|
||||||
template_str: str,
|
|
||||||
description: str = "",
|
|
||||||
default_variables: Optional[Dict[str, Any]] = None,
|
|
||||||
special_tokens: Optional[Dict[str, str]] = None,
|
|
||||||
) -> ChatTemplate:
|
|
||||||
"""Register a chat template in the global registry."""
|
|
||||||
template = ChatTemplate(
|
|
||||||
name=name,
|
|
||||||
template_str=template_str,
|
|
||||||
description=description,
|
|
||||||
default_variables=default_variables,
|
|
||||||
special_tokens=special_tokens,
|
|
||||||
)
|
|
||||||
_default_registry.register(name, template, category=None, priority=0)
|
|
||||||
return template
|
|
||||||
|
|
||||||
|
|
||||||
def set_default_chat_template(name: str) -> None:
|
|
||||||
"""Set the default chat template name globally."""
|
|
||||||
global _default_template_name
|
|
||||||
if not _default_registry.contains(name):
|
|
||||||
raise KeyError(
|
|
||||||
f"Chat template '{name}' not found. Available: {list(_default_registry.list_names())}"
|
|
||||||
)
|
|
||||||
_default_template_name = name
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_chat_template_name() -> str:
|
|
||||||
"""Get the current default chat template name."""
|
|
||||||
return _default_template_name
|
|
||||||
|
|
||||||
|
|
||||||
def get_chat_template(name: str) -> ChatTemplate:
|
|
||||||
"""Get a chat template from the global registry."""
|
|
||||||
return _default_registry.get(name)
|
|
||||||
|
|
||||||
|
|
||||||
def list_chat_templates() -> List[str]:
|
|
||||||
"""List all registered chat template names."""
|
|
||||||
return _default_registry.list_names()
|
|
||||||
|
|
||||||
|
|
||||||
def chat_template_exists(name: str) -> bool:
|
|
||||||
"""Check if a chat template exists."""
|
|
||||||
return _default_registry.contains(name)
|
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(
|
|
||||||
query: str,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
history: Optional[HistoryType] = None,
|
|
||||||
template: Optional[str] = None,
|
|
||||||
template_name: Optional[str] = None,
|
|
||||||
**extra_variables: Any,
|
|
||||||
) -> str:
|
|
||||||
"""Build prompt using a registered chat template or a custom template string.
|
|
||||||
|
|
||||||
This function maintains backward compatibility with the previous API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The current user query.
|
|
||||||
system_prompt: Optional system prompt.
|
|
||||||
history: Optional list of (user_msg, assistant_msg) pairs.
|
|
||||||
template: If provided, uses this exact Jinja2 template string (overrides template_name).
|
|
||||||
template_name: Name of a registered template to use (ignored if `template` is given).
|
|
||||||
If None, uses the globally set default template (see `set_default_chat_template`).
|
|
||||||
**extra_variables: Additional variables to pass to the template.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rendered prompt string.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If `template_name` is not registered.
|
|
||||||
"""
|
|
||||||
# Convert history to message format
|
|
||||||
messages: List[MessageType] = []
|
|
||||||
if history:
|
|
||||||
for user_msg, assistant_msg in history:
|
|
||||||
messages.append({"role": "user", "content": user_msg})
|
|
||||||
messages.append({"role": "assistant", "content": assistant_msg})
|
|
||||||
messages.append({"role": "user", "content": query})
|
|
||||||
|
|
||||||
if template is not None:
|
|
||||||
# Use the provided template string directly
|
|
||||||
jinja_template = Template(template)
|
|
||||||
variables = {"messages": messages, **extra_variables}
|
|
||||||
if system_prompt is not None:
|
|
||||||
variables["system_prompt"] = system_prompt
|
|
||||||
return jinja_template.render(**variables)
|
|
||||||
else:
|
|
||||||
# Determine which template name to use
|
|
||||||
if template_name is None:
|
|
||||||
template_name = _default_template_name
|
|
||||||
# Use a registered template
|
|
||||||
chat_template = get_chat_template(template_name)
|
|
||||||
return chat_template.render(
|
|
||||||
messages=messages,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
**extra_variables,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Predefined templates
|
|
||||||
# ChatML template (original)
|
|
||||||
register_chat_template(
|
|
||||||
name="chatml",
|
|
||||||
template_str=(
|
|
||||||
"{%- if system_prompt -%}\n"
|
|
||||||
"{{ bos_token }}system\n"
|
|
||||||
"{{ system_prompt }}{{ eos_token }}\n"
|
|
||||||
"{%- endif -%}\n"
|
|
||||||
"{%- for message in messages -%}\n"
|
|
||||||
"{{ bos_token }}{{ message['role'] }}\n"
|
|
||||||
"{{ message['content'] }}{{ eos_token }}\n"
|
|
||||||
"{%- endfor -%}\n"
|
|
||||||
"{{ bos_token }}assistant\n"
|
|
||||||
),
|
|
||||||
description="ChatML format with configurable special tokens.",
|
|
||||||
special_tokens={"bos_token": "<|im▁start|>", "eos_token": "<|im▁end|>"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simplified template without special tokens (plain text)
|
|
||||||
register_chat_template(
|
|
||||||
name="plain",
|
|
||||||
template_str=(
|
|
||||||
"{%- if system_prompt -%}\n"
|
|
||||||
"System: {{ system_prompt }}\n"
|
|
||||||
"{%- endif -%}\n"
|
|
||||||
"{%- for message in messages -%}\n"
|
|
||||||
"{{ message['role']|capitalize }}: {{ message['content'] }}\n"
|
|
||||||
"{%- endfor -%}\n"
|
|
||||||
"Assistant:"
|
|
||||||
),
|
|
||||||
description="Plain text format with role labels.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Alpaca-style template
|
|
||||||
register_chat_template(
|
|
||||||
name="alpaca",
|
|
||||||
template_str=(
|
|
||||||
"{%- if system_prompt -%}\n"
|
|
||||||
"### Instruction:\n"
|
|
||||||
"{{ system_prompt }}\n"
|
|
||||||
"{%- endif -%}\n"
|
|
||||||
"### Input:\n"
|
|
||||||
"{{ messages[-1]['content'] }}\n"
|
|
||||||
"### Response:"
|
|
||||||
),
|
|
||||||
description="Alpaca instruction‑response format (single‑turn).",
|
|
||||||
default_variables={},
|
|
||||||
)
|
|
||||||
|
|
||||||
# OpenAI chat format (approximation)
|
|
||||||
register_chat_template(
|
|
||||||
name="openai",
|
|
||||||
template_str=(
|
|
||||||
"{%- if system_prompt -%}\n"
|
|
||||||
"{{ bos_token }}system\n"
|
|
||||||
"{{ system_prompt }}{{ eos_token }}\n"
|
|
||||||
"{%- endif -%}\n"
|
|
||||||
"{%- for message in messages -%}\n"
|
|
||||||
"{{ bos_token }}{{ message['role'] }}\n"
|
|
||||||
"{{ message['content'] }}{{ eos_token }}\n"
|
|
||||||
"{%- endfor -%}\n"
|
|
||||||
"{{ bos_token }}assistant\n"
|
|
||||||
),
|
|
||||||
description="OpenAI‑compatible chat format with configurable special tokens.",
|
|
||||||
special_tokens={"bos_token": "<|im▁start|>", "eos_token": "<|im▁end|>"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Llama‑2 style with [INST] tags
|
|
||||||
register_chat_template(
|
|
||||||
name="llama2",
|
|
||||||
template_str=(
|
|
||||||
"{%- if system_prompt -%}\n"
|
|
||||||
"<<SYS>>\n"
|
|
||||||
"{{ system_prompt }}\n"
|
|
||||||
"<</SYS>>\n"
|
|
||||||
"{%- endif -%}\n"
|
|
||||||
"[INST] {{ messages[-1]['content'] }} [/INST]"
|
|
||||||
),
|
|
||||||
description="Llama‑2 style with [INST] tags (single‑turn).",
|
|
||||||
default_variables={},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ChatTemplate",
|
|
||||||
"register_chat_template",
|
|
||||||
"get_chat_template",
|
|
||||||
"list_chat_templates",
|
|
||||||
"chat_template_exists",
|
|
||||||
"build_prompt",
|
|
||||||
"set_default_chat_template",
|
|
||||||
"get_default_chat_template_name",
|
|
||||||
"HistoryType",
|
|
||||||
"MessageType",
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -1,97 +1,254 @@
|
||||||
from abc import ABC, abstractmethod
|
"""
|
||||||
from typing import List, Union
|
Tokenizer module with implementation and auto-loading support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
||||||
from tokenizers.models import BPE
|
from tokenizers.models import BPE
|
||||||
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
|
from astrai.tokenize.chat_template import ChatTemplate
|
||||||
|
|
||||||
|
|
||||||
class BaseTokenizer(ABC):
|
class TextTokenizer:
|
||||||
@abstractmethod
|
"""Base tokenizer class with automatic loading support"""
|
||||||
def _init_tokenizer(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
TOKENIZER_CLASSES = {} # Registry for auto-loading
|
||||||
def save(self, path):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
def __init__(
|
||||||
def load(self, path):
|
self,
|
||||||
pass
|
path: Optional[Union[str, Path]] = None,
|
||||||
|
special_token_map: Optional[Dict[str, str]] = None,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._tokenizer: Tokenizer = None
|
||||||
|
self._chat_template: Optional[ChatTemplate] = None
|
||||||
|
self._special_token_map: Optional[Dict] = special_token_map or {}
|
||||||
|
|
||||||
|
if chat_template:
|
||||||
|
self.set_chat_template(chat_template)
|
||||||
|
|
||||||
|
if path:
|
||||||
|
self.load(path)
|
||||||
|
|
||||||
|
def load(self, path: Union[str, Path]):
|
||||||
|
"""Load tokenizer from directory."""
|
||||||
|
path = Path(path)
|
||||||
|
tokenizer_file = path / "tokenizer.json"
|
||||||
|
config_file = path / "tokenizer_config.json"
|
||||||
|
self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
|
||||||
|
|
||||||
|
if config_file.exists():
|
||||||
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
if "special_tokens" in config:
|
||||||
|
self._special_token_map.update(config["special_tokens"])
|
||||||
|
|
||||||
|
# Load chat template from config
|
||||||
|
if "chat_template" in config:
|
||||||
|
self.set_chat_template(config["chat_template"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
|
||||||
|
"""Load tokenizer from pretrained directory."""
|
||||||
|
instance = cls(path)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def save_pretrained(self, tokenizer, save_path: str):
|
||||||
|
"""
|
||||||
|
Save tokenizer to pretrained directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer: Tokenizer instance to save
|
||||||
|
save_path: Path to save the tokenizer
|
||||||
|
"""
|
||||||
|
save_path = Path(save_path)
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._tokenizer.save(tokenizer, save_path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_tokenizer(cls, name: str, tokenizer_class: type):
|
||||||
|
"""
|
||||||
|
Register a new tokenizer class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name to register the tokenizer class under
|
||||||
|
tokenizer_class: The tokenizer class to register
|
||||||
|
"""
|
||||||
|
cls.TOKENIZER_CLASSES[name] = tokenizer_class
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
tokens: Union[str, List[str]],
|
tokens: Union[str, List[str]],
|
||||||
out_ids: bool = True,
|
out_ids: bool = True,
|
||||||
add_special_tokens: bool = False,
|
is_pretokenized: bool = False,
|
||||||
|
add_special_tokens: bool = True,
|
||||||
) -> List:
|
) -> List:
|
||||||
pass
|
"""Encode text to tokens or token IDs."""
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(tokens, str):
|
||||||
|
encoded = self._tokenizer.encode(
|
||||||
|
tokens,
|
||||||
|
is_pretokenized=is_pretokenized,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
)
|
||||||
|
return encoded.ids if out_ids else encoded.tokens
|
||||||
|
else:
|
||||||
|
encoded_list = self._tokenizer.encode_batch(
|
||||||
|
tokens,
|
||||||
|
is_pretokenized=is_pretokenized,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
|
||||||
|
]
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
||||||
pass
|
"""Decode token IDs to text."""
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
pass
|
if self._tokenizer is None:
|
||||||
|
return 0
|
||||||
|
return self._tokenizer.get_vocab_size()
|
||||||
|
|
||||||
|
def __getattr__(self, key: str):
|
||||||
|
"""
|
||||||
|
Dynamically intercept special token attribute access.
|
||||||
|
Supports three forms:
|
||||||
|
- tokenizer.bos_token → returns string
|
||||||
|
- tokenizer.bos_token_id → returns corresponding integer ID
|
||||||
|
- tokenizer.stop_ids → returns list of corresponding integer IDs
|
||||||
|
"""
|
||||||
|
# Handle stop_ids
|
||||||
|
if key == "stop_ids":
|
||||||
|
return [
|
||||||
|
self._special_token_map.get(val)
|
||||||
|
for val in self._special_token_map.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Handle _id suffix (e.g., bos_token_id -> bos_token)
|
||||||
|
if key.endswith("_id"):
|
||||||
|
base_attr = key[:-3] # Remove "_id"
|
||||||
|
token_str = self._special_token_map.get(base_attr)
|
||||||
|
if token_str is None:
|
||||||
|
return None
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
|
||||||
|
return self._tokenizer.token_to_id(token_str)
|
||||||
|
|
||||||
|
# Handle regular string attributes
|
||||||
|
if key in self._special_token_map:
|
||||||
|
return self._special_token_map.get(key)
|
||||||
|
|
||||||
|
# Other attributes trigger default AttributeError
|
||||||
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
def vocab_size(self) -> int:
|
||||||
def stop_ids(self) -> List[int]:
|
return len(self)
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
def pad_id(self) -> Optional[int]:
|
||||||
def bos_id(self) -> int:
|
"""Return the pad token ID if available."""
|
||||||
pass
|
pad_token = self._special_token_map.get("pad")
|
||||||
|
if pad_token is None or self._tokenizer is None:
|
||||||
|
return None
|
||||||
|
return self._tokenizer.token_to_id(pad_token)
|
||||||
|
|
||||||
@property
|
def set_chat_template(self, template: Union[str, ChatTemplate]):
|
||||||
@abstractmethod
|
"""
|
||||||
def eos_id(self) -> int:
|
Set the chat template for the tokenizer.
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
Args:
|
||||||
@abstractmethod
|
template: Either a template name (str) registered in the global registry,
|
||||||
def pad_id(self) -> int:
|
or a ChatTemplate instance, or a Jinja2 template string.
|
||||||
pass
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If template name is not registered.
|
||||||
|
"""
|
||||||
|
if isinstance(template, str):
|
||||||
|
self._chat_template = ChatTemplate.from_string(template)
|
||||||
|
elif isinstance(template, ChatTemplate):
|
||||||
|
self._chat_template = template
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid template type, must be str or ChatTemplate.")
|
||||||
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tokenize: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[str, List[int]]:
|
||||||
|
"""
|
||||||
|
Apply the chat template to messages and optionally tokenize the result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content'.
|
||||||
|
system_prompt: Optional system prompt string.
|
||||||
|
tokenize: Whether to return token IDs (True) or raw string (False).
|
||||||
|
**kwargs: Additional variables to pass to the template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either the rendered string or list of token IDs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If chat template is not set.
|
||||||
|
"""
|
||||||
|
if self._chat_template is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Chat template not set. Use set_chat_template() to set a template first."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Render the template
|
||||||
|
rendered = self._chat_template.render(
|
||||||
|
messages=messages,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tokenize:
|
||||||
|
return self.encode(rendered)
|
||||||
|
|
||||||
|
return rendered
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainer(ABC):
|
class BpeTokenizer(TextTokenizer):
|
||||||
def __init__(self, tokenizer: BaseTokenizer):
|
"""BPE tokenizer implementation."""
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def train(self, files, vocab_size, min_freq, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def train_from_iterator(self, iterator, vocab_size, min_freq, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BpeTokenizer(BaseTokenizer):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
control_tokens: List[str] = None,
|
special_token_map: Dict[str, str] = None,
|
||||||
special_tokens: List[str] = None,
|
path: Optional[str] = None,
|
||||||
path=None,
|
chat_template: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self._control_tokens = control_tokens or [
|
special_token_map = special_token_map or {
|
||||||
"<|begin▁of▁sentence|>",
|
"bos": "<|begin▁of▁sentence|>",
|
||||||
"<|end▁of▁sentence|>",
|
"eos": "<|end▁of▁sentence|>",
|
||||||
"<|▁pad▁|>",
|
"pad": "<|▁pad▁|>",
|
||||||
]
|
"im_start": "<|im▁start|>",
|
||||||
self._special_tokens = special_tokens or [
|
"im_end": "<|im▁end|>",
|
||||||
"<|im▁start|>",
|
}
|
||||||
"<|im▁end|>",
|
|
||||||
]
|
|
||||||
self._tokenizer = None
|
self._tokenizer = None
|
||||||
self._init_tokenizer()
|
self._init_tokenizer()
|
||||||
if path is not None:
|
super().__init__(
|
||||||
self.load(path)
|
path, special_token_map=special_token_map, chat_template=chat_template
|
||||||
|
)
|
||||||
|
|
||||||
def _init_tokenizer(self):
|
def _init_tokenizer(self):
|
||||||
|
"""Initialize a new BPE tokenizer with default settings."""
|
||||||
model = BPE()
|
model = BPE()
|
||||||
self._tokenizer = Tokenizer(model)
|
self._tokenizer = Tokenizer(model)
|
||||||
self._tokenizer.normalizer = normalizers.Sequence(
|
self._tokenizer.normalizer = normalizers.Sequence(
|
||||||
|
|
@ -105,108 +262,3 @@ class BpeTokenizer(BaseTokenizer):
|
||||||
)
|
)
|
||||||
self._tokenizer.decoder = decoders.ByteLevel()
|
self._tokenizer.decoder = decoders.ByteLevel()
|
||||||
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
|
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
|
||||||
|
|
||||||
def save(self, path):
|
|
||||||
self._tokenizer.save(path)
|
|
||||||
|
|
||||||
def load(self, path):
|
|
||||||
self._tokenizer = Tokenizer.from_file(path)
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
tokens: Union[str, List[str]],
|
|
||||||
out_ids: bool = True,
|
|
||||||
add_special_tokens: bool = False,
|
|
||||||
) -> List:
|
|
||||||
if isinstance(tokens, str):
|
|
||||||
encoded = self._tokenizer.encode(
|
|
||||||
tokens, add_special_tokens=add_special_tokens
|
|
||||||
)
|
|
||||||
return encoded.ids if out_ids else encoded.tokens
|
|
||||||
else:
|
|
||||||
encoded_list = self._tokenizer.encode_batch(
|
|
||||||
tokens, add_special_tokens=add_special_tokens
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
|
|
||||||
]
|
|
||||||
|
|
||||||
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
|
|
||||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self._tokenizer.get_vocab_size()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def stop_ids(self) -> List[int]:
|
|
||||||
stop_token = self._control_tokens + self._special_tokens
|
|
||||||
return [self._tokenizer.token_to_id(tok) for tok in stop_token]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bos_id(self) -> int:
|
|
||||||
return self._tokenizer.token_to_id(self._control_tokens[0])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eos_id(self) -> int:
|
|
||||||
return self._tokenizer.token_to_id(self._control_tokens[1])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pad_id(self) -> int:
|
|
||||||
return self._tokenizer.token_to_id(self._control_tokens[2])
|
|
||||||
|
|
||||||
|
|
||||||
class BpeTrainer(BaseTrainer):
|
|
||||||
def __init__(self, tokenizer: BaseTokenizer):
|
|
||||||
super().__init__(tokenizer)
|
|
||||||
|
|
||||||
def _prepare_trainer(
|
|
||||||
self,
|
|
||||||
vocab_size: int,
|
|
||||||
min_freq: int,
|
|
||||||
reserved_token_size: int,
|
|
||||||
max_token_length=18,
|
|
||||||
):
|
|
||||||
assert reserved_token_size > len(self.tokenizer._special_tokens)
|
|
||||||
reserved_tokens = [
|
|
||||||
f"<|reserve{i:02d}|>"
|
|
||||||
for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
|
|
||||||
]
|
|
||||||
detail_vocab_size = vocab_size - (
|
|
||||||
len(reserved_tokens) + len(self.tokenizer._special_tokens)
|
|
||||||
)
|
|
||||||
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
|
||||||
min_size = len(alphabet) + len(self.tokenizer._control_tokens)
|
|
||||||
assert detail_vocab_size > min_size
|
|
||||||
|
|
||||||
trainer = BpeTrainerImpl(
|
|
||||||
vocab_size=detail_vocab_size,
|
|
||||||
min_frequency=min_freq,
|
|
||||||
limit_alphabet=detail_vocab_size // 6,
|
|
||||||
max_token_length=max_token_length,
|
|
||||||
special_tokens=self.tokenizer._control_tokens,
|
|
||||||
initial_alphabet=alphabet,
|
|
||||||
show_progress=True,
|
|
||||||
)
|
|
||||||
return trainer, reserved_tokens
|
|
||||||
|
|
||||||
def train(self, files, vocab_size, min_freq, reserved_token_size=100, **kwargs):
|
|
||||||
trainer, reserved_tokens = self._prepare_trainer(
|
|
||||||
vocab_size, min_freq, reserved_token_size, **kwargs
|
|
||||||
)
|
|
||||||
self.tokenizer._tokenizer.train(files=files, trainer=trainer)
|
|
||||||
self.tokenizer._tokenizer.add_special_tokens(
|
|
||||||
self.tokenizer._special_tokens + reserved_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
def train_from_iterator(
|
|
||||||
self, iterator, vocab_size, min_freq, reserved_token_size=100, **kwargs
|
|
||||||
):
|
|
||||||
trainer, reserved_tokens = self._prepare_trainer(
|
|
||||||
vocab_size, min_freq, reserved_token_size, **kwargs
|
|
||||||
)
|
|
||||||
self.tokenizer._tokenizer.train_from_iterator(
|
|
||||||
iterator=iterator, trainer=trainer
|
|
||||||
)
|
|
||||||
self.tokenizer._tokenizer.add_special_tokens(
|
|
||||||
self.tokenizer._special_tokens + reserved_tokens
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
"""
|
||||||
|
BPE Tokenizer Trainer module.
|
||||||
|
|
||||||
|
Provides training functionality for BPE tokenizers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from tokenizers import pre_tokenizers
|
||||||
|
from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
|
||||||
|
|
||||||
|
|
||||||
|
class BpeTrainer:
|
||||||
|
"""BPE tokenizer trainer."""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer):
|
||||||
|
"""Initialize trainer with a tokenizer instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer: A BpeTokenizer instance
|
||||||
|
"""
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def _prepare_trainer(
|
||||||
|
self,
|
||||||
|
vocab_size: int,
|
||||||
|
min_freq: int,
|
||||||
|
reserved_token_size: int,
|
||||||
|
max_token_length: int = 18,
|
||||||
|
):
|
||||||
|
"""Prepare the BPE trainer with proper configuration."""
|
||||||
|
assert reserved_token_size > len(self.tokenizer._special_tokens)
|
||||||
|
reserved_tokens = [
|
||||||
|
f"<|reserve{i:02d}|>"
|
||||||
|
for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
|
||||||
|
]
|
||||||
|
detail_vocab_size = vocab_size - (
|
||||||
|
len(reserved_tokens) + len(self.tokenizer._special_tokens)
|
||||||
|
)
|
||||||
|
alphabet = pre_tokenizers.ByteLevel.alphabet()
|
||||||
|
min_size = len(alphabet) + len(self.tokenizer._control_tokens)
|
||||||
|
assert detail_vocab_size > min_size
|
||||||
|
|
||||||
|
trainer = BpeTrainerImpl(
|
||||||
|
vocab_size=detail_vocab_size,
|
||||||
|
min_frequency=min_freq,
|
||||||
|
limit_alphabet=detail_vocab_size // 6,
|
||||||
|
max_token_length=max_token_length,
|
||||||
|
special_tokens=self.tokenizer._control_tokens,
|
||||||
|
initial_alphabet=alphabet,
|
||||||
|
show_progress=True,
|
||||||
|
)
|
||||||
|
return trainer, reserved_tokens
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
files: Union[str, List[str]],
|
||||||
|
vocab_size: int,
|
||||||
|
min_freq: int,
|
||||||
|
reserved_token_size: int = 100,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Train tokenizer from files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: Path or list of paths to training files
|
||||||
|
vocab_size: Target vocabulary size
|
||||||
|
min_freq: Minimum frequency for tokens
|
||||||
|
reserved_token_size: Number of reserved tokens
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
"""
|
||||||
|
trainer, reserved_tokens = self._prepare_trainer(
|
||||||
|
vocab_size, min_freq, reserved_token_size, **kwargs
|
||||||
|
)
|
||||||
|
self.tokenizer._tokenizer.train(files=files, trainer=trainer)
|
||||||
|
self.tokenizer._tokenizer.add_special_tokens(
|
||||||
|
self.tokenizer._special_tokens + reserved_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_from_iterator(
|
||||||
|
self,
|
||||||
|
iterator,
|
||||||
|
vocab_size: int,
|
||||||
|
min_freq: int,
|
||||||
|
reserved_token_size: int = 100,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Train tokenizer from iterator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
iterator: Iterator yielding training strings
|
||||||
|
vocab_size: Target vocabulary size
|
||||||
|
min_freq: Minimum frequency for tokens
|
||||||
|
reserved_token_size: Number of reserved tokens
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
"""
|
||||||
|
trainer, reserved_tokens = self._prepare_trainer(
|
||||||
|
vocab_size, min_freq, reserved_token_size, **kwargs
|
||||||
|
)
|
||||||
|
self.tokenizer._tokenizer.train_from_iterator(
|
||||||
|
iterator=iterator, trainer=trainer
|
||||||
|
)
|
||||||
|
self.tokenizer._tokenizer.add_special_tokens(
|
||||||
|
self.tokenizer._special_tokens + reserved_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BpeTrainer"]
|
||||||
|
|
@ -2,24 +2,32 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
|
||||||
from astrai.inference import InferenceEngine
|
from astrai.inference import InferenceEngine
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
def generate_text():
|
def generate_text():
|
||||||
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
|
# Load model from pretrained
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||||
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Load tokenizer from pretrained
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT / "tokenizer")
|
||||||
|
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
|
|
||||||
engine = InferenceEngine(param)
|
engine = InferenceEngine(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
response = engine.generate(
|
response = engine.generate(
|
||||||
prompt=query,
|
prompt=query,
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=param.config.max_len,
|
max_tokens=2048,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.model import AutoModel
|
||||||
from astrai.inference import InferenceEngine
|
from astrai.inference import InferenceEngine
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
|
@ -10,8 +10,10 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
def batch_generate():
|
def batch_generate():
|
||||||
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
|
# Load model using AutoModel
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
model = AutoModel.from_pretrained(
|
||||||
|
PARAMETER_ROOT, device="cuda", dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
"你好",
|
"你好",
|
||||||
|
|
@ -21,11 +23,14 @@ def batch_generate():
|
||||||
"请问什么是显卡",
|
"请问什么是显卡",
|
||||||
]
|
]
|
||||||
|
|
||||||
engine = InferenceEngine(param)
|
engine = InferenceEngine(
|
||||||
|
model=model.model,
|
||||||
|
tokenizer=model.tokenizer,
|
||||||
|
)
|
||||||
responses = engine.generate(
|
responses = engine.generate(
|
||||||
prompt=inputs,
|
prompt=inputs,
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=param.config.max_len,
|
max_tokens=model.config.max_len,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,39 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
|
||||||
from astrai.inference import InferenceEngine
|
from astrai.inference import InferenceEngine
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
def chat():
|
def chat():
|
||||||
param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
|
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
history = []
|
messages = []
|
||||||
engine = InferenceEngine(param)
|
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
if query == "!exit":
|
if query == "!exit":
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Add user message
|
||||||
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
|
# Generate response
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
for token in engine.generate(
|
for token in engine.generate(
|
||||||
prompt=query,
|
prompt=prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
max_tokens=param.config.max_len,
|
max_tokens=model.config.max_len,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
|
|
@ -35,7 +42,8 @@ def chat():
|
||||||
full_response += token
|
full_response += token
|
||||||
|
|
||||||
print()
|
print()
|
||||||
history.append((query, full_response.strip()))
|
# Add assistant response to messages
|
||||||
|
messages.append({"role": "assistant", "content": full_response.strip()})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ import json
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
from astrai.inference import InferenceEngine
|
from astrai.inference import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -17,9 +18,9 @@ def processor(
|
||||||
question_key: str,
|
question_key: str,
|
||||||
response_key: str,
|
response_key: str,
|
||||||
):
|
):
|
||||||
param = ModelParameter.load(model_dir, disable_init=True)
|
# Load model using AutoModel
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
|
||||||
engine = InferenceEngine(param)
|
engine = InferenceEngine(model=model.model, tokenizer=model.tokenizer)
|
||||||
|
|
||||||
with open(input_json_file, "r", encoding="utf-8") as f:
|
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
@ -29,7 +30,7 @@ def processor(
|
||||||
responses = engine.generate(
|
responses = engine.generate(
|
||||||
prompt=queries,
|
prompt=queries,
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=param.config.max_len,
|
max_tokens=model.config.max_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
||||||
import tqdm
|
import tqdm
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.model import AutoModel
|
||||||
|
|
||||||
|
|
||||||
def compute_perplexity(
|
def compute_perplexity(
|
||||||
|
|
@ -20,7 +20,7 @@ def compute_perplexity(
|
||||||
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
where PPL = exp(-(1/N) * sum(log P(w_i | w_<i))).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output = model(input_ids, input_mask)
|
output = model(input_ids, input_mask=input_mask)
|
||||||
logits = output["logits"]
|
logits = output["logits"]
|
||||||
|
|
||||||
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
shifted_logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||||
|
|
@ -42,10 +42,9 @@ def compute_perplexity(
|
||||||
def process_file(
|
def process_file(
|
||||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||||
):
|
):
|
||||||
param = ModelParameter.load(model_dir, disable_init=True)
|
# Load model using AutoModel
|
||||||
param.to(device="cuda", dtype=torch.bfloat16)
|
model = AutoModel.from_pretrained(model_dir, device="cuda", dtype=torch.bfloat16)
|
||||||
model = param.model
|
tokenizer = model.tokenizer
|
||||||
tokenizer = param.tokenizer
|
|
||||||
|
|
||||||
with open(input_file, "r", encoding="utf-8") as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
@ -54,7 +53,7 @@ def process_file(
|
||||||
encoded_texts = [tokenizer.encode(text) for text in texts]
|
encoded_texts = [tokenizer.encode(text) for text in texts]
|
||||||
output_data = []
|
output_data = []
|
||||||
|
|
||||||
for i in tqdm(
|
for i in tqdm.tqdm(
|
||||||
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
|
range(0, len(encoded_texts), batch_size), desc="Computing perplexity"
|
||||||
):
|
):
|
||||||
batch_encoded = encoded_texts[i : i + batch_size]
|
batch_encoded = encoded_texts[i : i + batch_size]
|
||||||
|
|
@ -72,7 +71,7 @@ def process_file(
|
||||||
|
|
||||||
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
input_ids = torch.tensor(padded_ids, device="cuda", dtype=torch.long)
|
||||||
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
input_mask = torch.tensor(masks, device="cuda", dtype=torch.bool)
|
||||||
perplexity = compute_perplexity(model, input_ids, input_mask)
|
perplexity = compute_perplexity(model.model, input_ids, input_mask)
|
||||||
|
|
||||||
for text, ppl in zip(batch_texts, perplexity):
|
for text, ppl in zip(batch_texts, perplexity):
|
||||||
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
output_data.append({text_key: text, "ppl": float(ppl.item())})
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,12 @@ from functools import partial
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
import safetensors.torch as st
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.config import ModelParameter, TrainConfig
|
from astrai.config import ModelConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
|
from astrai.model import Transformer
|
||||||
from astrai.parallel import get_rank
|
from astrai.parallel import get_rank
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
@ -196,12 +198,23 @@ def train(
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
parameter = ModelParameter.load(param_path)
|
# Load config
|
||||||
|
config = ModelConfig()
|
||||||
|
config_path = os.path.join(param_path, "config.json")
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
config.load(config_path)
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = parameter.config.max_len
|
window_size = config.max_len
|
||||||
|
|
||||||
model = parameter.model
|
# Create bare Transformer (for training, no tokenizer needed)
|
||||||
|
model = Transformer(config)
|
||||||
|
|
||||||
|
# Load weights if available
|
||||||
|
weights_path = os.path.join(param_path, "model.safetensors")
|
||||||
|
if os.path.exists(weights_path):
|
||||||
|
state_dict = st.load_file(weights_path)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from astrai.inference.server import app, _engine
|
from astrai.inference.server import app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_parameter(test_env):
|
|
||||||
save_dir = os.path.join(test_env["test_dir"], "save")
|
|
||||||
model_param = ModelParameter(
|
|
||||||
test_env["model"], test_env["tokenizer"], test_env["transformer_config"]
|
|
||||||
)
|
|
||||||
ModelParameter.save(model_param, save_dir)
|
|
||||||
|
|
||||||
assert os.path.exists(os.path.join(save_dir, "model.safetensors"))
|
|
||||||
assert os.path.exists(os.path.join(save_dir, "tokenizer.json"))
|
|
||||||
assert os.path.exists(os.path.join(save_dir, "config.json"))
|
|
||||||
|
|
||||||
|
|
||||||
# transformer
|
|
||||||
def test_transformer(test_env):
|
|
||||||
model = test_env["model"]
|
|
||||||
input_ids = torch.randint(
|
|
||||||
0,
|
|
||||||
test_env["transformer_config"].vocab_size,
|
|
||||||
(4, test_env["transformer_config"].max_len),
|
|
||||||
)
|
|
||||||
output_logits = model(input_ids)["logits"]
|
|
||||||
target_shape = (
|
|
||||||
4,
|
|
||||||
test_env["transformer_config"].max_len,
|
|
||||||
test_env["transformer_config"].vocab_size,
|
|
||||||
)
|
|
||||||
assert output_logits.shape == target_shape
|
|
||||||
Loading…
Reference in New Issue