From fc278d17ab8a2d5c70e587555232c0b2773645a8 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 5 Apr 2026 19:38:12 +0800
Subject: [PATCH] feat: implement a dynamic model registration mechanism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
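
Replace the ModelParameter/BaseModelIO container with an AutoModel base class
that keeps a model-type registry: subclasses register themselves via
@AutoModel.register("<type>") and AutoModel.from_pretrained() dispatches on the
model_type field stored in config.json. The tokenizer stack is reworked
alongside it: TextTokenizer now owns the special-token map and chat template
(apply_chat_template), BpeTrainer moves into tokenize/trainer.py, and the
global chat-template registry in chat_template.py is removed. InferenceEngine
takes the model and tokenizer separately instead of a ModelParameter.

Illustrative usage after this change (directory names are placeholders):

    from astrai.inference import InferenceEngine
    from astrai.model import AutoModel
    from astrai.tokenize import AutoTokenizer

    model = AutoModel.from_pretrained("params")
    tokenizer = AutoTokenizer.from_pretrained("params")
    engine = InferenceEngine(model=model, tokenizer=tokenizer)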
---
.gitattributes | 1 +
assets/docs/params.md | 8 +-
astrai/__init__.py | 3 +-
astrai/config/__init__.py | 4 -
astrai/config/model_config.py | 1 +
astrai/config/param_config.py | 114 ---------
astrai/factory.py | 3 +
astrai/inference/engine.py | 54 +++--
astrai/inference/scheduler.py | 33 +--
astrai/inference/server.py | 76 +++---
astrai/model/__init__.py | 14 +-
astrai/model/automodel.py | 134 +++++++++++
astrai/model/transformer.py | 30 ++-
astrai/tokenize/__init__.py | 17 +-
astrai/tokenize/chat_template.py | 211 -----------------
astrai/tokenize/tokenizer.py | 384 ++++++++++++++++++-------------
astrai/tokenize/trainer.py | 108 +++++++++
scripts/demo/generate_ar.py | 18 +-
scripts/demo/generate_batch.py | 15 +-
scripts/demo/stream_chat.py | 26 ++-
scripts/tools/generate.py | 11 +-
scripts/tools/perplexity.py | 15 +-
scripts/tools/train.py | 21 +-
tests/inference/conftest.py | 2 +-
tests/module/test_module.py | 34 ---
25 files changed, 686 insertions(+), 651 deletions(-)
delete mode 100644 astrai/config/param_config.py
create mode 100644 astrai/model/automodel.py
create mode 100644 astrai/tokenize/trainer.py
delete mode 100644 tests/module/test_module.py
diff --git a/.gitattributes b/.gitattributes
index 1de0c60..60472b8 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -3,6 +3,7 @@
# Files that MUST use LF (Unix/Linux execution)
*.sh text eol=lf
+*.py text eol=lf
Dockerfile text eol=lf
.dockerignore text eol=lf
diff --git a/assets/docs/params.md b/assets/docs/params.md
index 9d4bf4b..6c6f550 100644
--- a/assets/docs/params.md
+++ b/assets/docs/params.md
@@ -90,8 +90,12 @@ from astrai.inference import InferenceEngine, GenerationRequest
param = ModelParameter.load("your_model_dir")
param.to(device="cuda", dtype=torch.bfloat16)
-# Create engine
-engine = InferenceEngine(param)
+# Create engine with separate model and tokenizer
+# (the engine reads the model configuration from model.config)
+engine = InferenceEngine(
+    model=param.model,
+    tokenizer=param.tokenizer,
+)
# Build request
request = GenerationRequest(
diff --git a/astrai/__init__.py b/astrai/__init__.py
index 423ab84..b98a372 100644
--- a/astrai/__init__.py
+++ b/astrai/__init__.py
@@ -12,7 +12,7 @@ from astrai.inference import (
GenerationRequest,
InferenceEngine,
)
-from astrai.model.transformer import Transformer
+from astrai.model import AutoModel, Transformer
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
__all__ = [
@@ -27,4 +27,5 @@ __all__ = [
"StrategyFactory",
"SchedulerFactory",
"BaseFactory",
+ "AutoModel",
]
diff --git a/astrai/config/__init__.py b/astrai/config/__init__.py
index 7b90769..bd47b59 100644
--- a/astrai/config/__init__.py
+++ b/astrai/config/__init__.py
@@ -1,11 +1,7 @@
from astrai.config.model_config import ModelConfig
-from astrai.config.param_config import BaseModelIO, ModelParameter
from astrai.config.train_config import TrainConfig
__all__ = [
- # Base I/O
- "BaseModelIO",
- "ModelParameter",
# Model configuration
"ModelConfig",
"TrainConfig",
diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py
index 9789cd1..cf1ebb8 100644
--- a/astrai/config/model_config.py
+++ b/astrai/config/model_config.py
@@ -6,6 +6,7 @@ from typing import Optional, Self
@dataclass
class ModelConfig:
# basic config
+ model_type: Optional[str] = None
vocab_size: Optional[int] = None
dim: Optional[int] = None
diff --git a/astrai/config/param_config.py b/astrai/config/param_config.py
deleted file mode 100644
index 9f15acf..0000000
--- a/astrai/config/param_config.py
+++ /dev/null
@@ -1,114 +0,0 @@
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Self, Union
-
-import safetensors.torch as st
-import torch.nn as nn
-
-from astrai.config.model_config import ModelConfig
-from astrai.tokenize import BpeTokenizer
-from astrai.model.transformer import Transformer
-
-
-@contextmanager
-def disable_random_init(enable: bool = True):
- init_functions = [
- "xavier_normal_",
- "xavier_uniform_",
- "kaiming_normal_",
- "kaiming_uniform_",
- "zeros_",
- "ones_",
- "constant_",
- "normal_",
- "uniform_",
- ]
- original_funcs = {}
- for name in init_functions:
- if enable and hasattr(nn.init, name):
- original_funcs[name] = getattr(nn.init, name)
- setattr(nn.init, name, lambda *args, **kwargs: None)
- try:
- yield
- finally:
- if enable:
- for name, orig_func in original_funcs.items():
- setattr(nn.init, name, orig_func)
-
-
-@dataclass
-class BaseModelIO:
- """Base class for model I/O operations."""
-
- model: nn.Module = field(
- default_factory=nn.Identity, metadata={"help": "Transformer model."}
- )
- tokenizer: BpeTokenizer = field(
- default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the model."}
- )
- config: ModelConfig = field(
- default_factory=ModelConfig,
- metadata={"help": "Transformer model configuration."},
- )
-
- def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
- """Get standardized file paths for model components."""
- dir_path = Path(directory)
- return {
- "model": dir_path / "model.safetensors",
- "config": dir_path / "config.json",
- "tokenizer": dir_path / "tokenizer.json",
- }
-
- def save_components(self, save_dir: Union[str, Path]):
- """Save core model components."""
- paths = self._get_file_paths(save_dir)
- paths["model"].parent.mkdir(parents=True, exist_ok=True)
-
- if self.model is not None:
- st.save_file(self.model.state_dict(), str(paths["model"]))
-
- self.config.save(str(paths["config"]))
- self.tokenizer.save(str(paths["tokenizer"]))
-
- def load_components(
- self, load_dir: Union[str, Path], disable_init: bool = False
- ) -> Self:
- """Load core model components."""
- paths = self._get_file_paths(load_dir)
-
- self.config.load(str(paths["config"]))
- self.tokenizer.load(str(paths["tokenizer"]))
-
- if isinstance(self.model, nn.Identity):
- with disable_random_init(enable=disable_init):
- self.model = Transformer(self.config)
-
- if paths["model"].exists():
- state_dict = st.load_file(str(paths["model"]))
- self.model.load_state_dict(state_dict)
-
- return self
-
- def to(self, *args, **kwargs) -> "BaseModelIO":
- """Move model to device."""
- if self.model is not None:
- self.model.to(*args, **kwargs)
- return self
-
-
-@dataclass
-class ModelParameter(BaseModelIO):
- """Container for model parameters with serialization capabilities."""
-
- @classmethod
- def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]):
- instance.save_components(save_dir)
-
- @classmethod
- def load(
- cls, load_dir: Union[str, Path], disable_init: bool = False
- ) -> "ModelParameter":
- instance = cls()
- return instance.load_components(load_dir, disable_init=disable_init)
diff --git a/astrai/factory.py b/astrai/factory.py
index 6c88c15..2109113 100644
--- a/astrai/factory.py
+++ b/astrai/factory.py
@@ -185,3 +185,6 @@ class BaseFactory(ABC, Generic[T]):
def list_by_priority(cls, reverse: bool = False) -> List[str]:
"""List registered component names sorted by priority."""
return cls._registry.list_by_priority(reverse)
+
+
+__all__ = ["Registry", "BaseFactory"]
diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py
index 5c86c53..14cc073 100644
--- a/astrai/inference/engine.py
+++ b/astrai/inference/engine.py
@@ -1,11 +1,11 @@
"""Unified inference engine."""
import threading
+import torch
+import torch.nn as nn
from typing import Any, Dict, Generator, List, Optional, Union
-from astrai.config import ModelParameter
-from astrai.tokenize.chat_template import build_prompt
-
+from astrai.tokenize.tokenizer import TextTokenizer
from astrai.inference.scheduler import InferenceScheduler
@@ -14,22 +14,18 @@ class GenerationRequest:
def __init__(
self,
- query: Union[str, List[str]],
+ messages: List[Dict[str, str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_len: int = 1024,
- history: Optional[Any] = None,
- system_prompt: Optional[str] = None,
stream: bool = False,
):
- self.query = query
+ self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
- self.history = history
- self.system_prompt = system_prompt
self.stream = stream
self._validate()
@@ -107,26 +103,41 @@ class InferenceEngine:
def __init__(
self,
- parameter: ModelParameter,
- max_batch_size: int = 16,
+ model: nn.Module,
+ tokenizer: TextTokenizer,
+ max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
):
- self.model = parameter.model
- self.tokenizer = parameter.tokenizer
- self.config = parameter.config
+ """
+ Initialize inference engine with separate model and tokenizer.
- model_params = next(self.model.parameters())
- self.device = model_params.device
- self.dtype = model_params.dtype
+ Args:
+        model: The language model for inference (an AutoModel subclass such as
+            Transformer); its config attribute supplies the model configuration
+        tokenizer: The tokenizer for encoding/decoding text
+ max_batch_size: Maximum batch size for continuous batching
+ max_seq_len: Maximum sequence length (defaults to config.max_len)
+ """
+ self.model = model
+ self.tokenizer = tokenizer
+
+ # Get device and dtype from model parameters
+ try:
+ first_param = next(model.parameters())
+ device = first_param.device
+ dtype = first_param.dtype
+ except StopIteration:
+ # Model has no parameters, use default device/dtype
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.float32
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
- config=self.config,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
- device=self.device,
- dtype=self.dtype,
+ device=device,
+ dtype=dtype,
)
self.kv_cache = self.scheduler.kv_cache
@@ -160,7 +171,8 @@ class InferenceEngine:
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object."""
- prompt = build_prompt(request.query, request.history)
+ # Use tokenizer's chat template with messages
+ prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,
diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py
index 58ceaca..7e44b07 100644
--- a/astrai/inference/scheduler.py
+++ b/astrai/inference/scheduler.py
@@ -8,7 +8,8 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from torch import Tensor
-from astrai.config import ModelConfig
+from astrai.model.automodel import AutoModel
+from astrai.tokenize.tokenizer import TextTokenizer
class TaskStatus:
@@ -98,23 +99,23 @@ class InferenceScheduler:
def __init__(
self,
- model,
- tokenizer,
- config: ModelConfig,
+ model: AutoModel,
+ tokenizer: TextTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
+ config = model.config
+
self.model = model
self.tokenizer = tokenizer
- self.config = config
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
- self.device = device
- self.dtype = dtype
+ self.device = device or next(model.parameters()).device
+ self.dtype = dtype or next(model.parameters()).dtype
- num_heads = config.n_kv_heads
+ num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
@@ -123,26 +124,26 @@ class InferenceScheduler:
max_batch_size,
self.max_seq_len,
n_layers,
- num_heads,
+ num_kv_heads,
head_dim,
),
- device=device,
- dtype=dtype,
+ device=self.device,
+ dtype=self.dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
- num_heads,
+ num_kv_heads,
head_dim,
),
- device=device,
- dtype=dtype,
+ device=self.device,
+ dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
- (max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
+ (max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
@@ -259,7 +260,7 @@ class InferenceScheduler:
)
with torch.inference_mode():
- outputs = self.model(
+ self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
diff --git a/astrai/inference/server.py b/astrai/inference/server.py
index c780add..aa6b801 100644
--- a/astrai/inference/server.py
+++ b/astrai/inference/server.py
@@ -3,8 +3,6 @@ Inference Server with Continuous Batching Support
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
-
-Author: AstrAI Team
"""
import json
@@ -19,14 +17,15 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
-from astrai.config.param_config import ModelParameter
-from astrai.inference.engine import GenerationRequest, InferenceEngine
+from astrai.inference.engine import InferenceEngine
+from astrai.model import AutoModel
+from astrai.tokenize import TextTokenizer
logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
-_model_param: Optional[ModelParameter] = None
_engine: Optional[InferenceEngine] = None
+_model_param: Optional[AutoModel] = None
_project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
@@ -95,13 +94,17 @@ def load_model(
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
- _model_param = ModelParameter.load(param_path, disable_init=True)
+
+ # Load tokenizer separately
+ tokenizer = TextTokenizer.from_pretrained(param_path)
+    _model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
- # Initialize inference engine with continuous batching
+ # Initialize inference engine with separate model and tokenizer
_engine = InferenceEngine(
- parameter=_model_param,
+ model=_model_param,
+ tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
@@ -164,27 +167,43 @@ def convert_messages_to_history(
return system_prompt, history if history else None
-def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
+def convert_messages_to_prompt(
+    messages: List[ChatMessage], engine: Optional[InferenceEngine] = None
+) -> str:
"""Convert messages to prompt string.
Args:
messages: List of ChatMessage objects
+ engine: InferenceEngine instance for accessing tokenizer
Returns:
str: Formatted prompt string
"""
- system_prompt, history = convert_messages_to_history(messages)
+ # Convert to dict format for chat template
+ msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
- # Get the last user message as query
- user_messages = [m.content for m in messages if m.role == "user"]
- if not user_messages:
- raise ValueError("No user message found")
- query = user_messages[-1]
+ # Extract system prompt if present
+ system_prompt = None
+ filtered_messages = []
+ for msg in msg_dicts:
+ if msg["role"] == "system":
+ system_prompt = msg["content"]
+ else:
+ filtered_messages.append(msg)
- # Build prompt using chat template
- from astrai.tokenize.chat_template import build_prompt
+ # Use engine's tokenizer chat template if available
+ if engine is not None and engine.tokenizer is not None:
+ return engine.tokenizer.apply_chat_template(
+ filtered_messages, system_prompt=system_prompt, tokenize=False
+ )
- return build_prompt(query, history)
+ # Fallback: simple concatenation (deprecated)
+ prompt_parts = []
+ for msg in filtered_messages:
+ prompt_parts.append(
+ f"<|im▁start|>{msg['role']}\n{msg['content']}<|im▁end|>"
+ )
+ return "\n".join(prompt_parts) + "\n<|im▁start|>assistant\n"
@app.get("/health")
@@ -213,8 +232,8 @@ async def chat_completion(request: ChatCompletionRequest):
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
- # Convert messages to prompt
- prompt = convert_messages_to_prompt(request.messages)
+ # Convert messages to prompt using engine's tokenizer
+ prompt = convert_messages_to_prompt(request.messages, engine=_engine)
if request.stream:
# Streaming response (use synchronous generator)
@@ -294,15 +313,18 @@ async def generate(
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
- # Convert history format
- hist: Optional[List[Tuple[str, str]]] = None
+ # Build messages for chat template
+ messages = []
if history:
- hist = [(h[0], h[1]) for h in history]
+ # Convert history format: List[List[str]] -> List[Dict]
+ for h in history:
+ if len(h) >= 2:
+ messages.append({"role": "user", "content": h[0]})
+ messages.append({"role": "assistant", "content": h[1]})
+ messages.append({"role": "user", "content": query})
- # Build prompt
- from astrai.tokenize.chat_template import build_prompt
-
- prompt = build_prompt(query, hist)
+ # Use tokenizer's chat template
+ prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream:
# Synchronous streaming
diff --git a/astrai/model/__init__.py b/astrai/model/__init__.py
index 516ccd3..35d74cc 100644
--- a/astrai/model/__init__.py
+++ b/astrai/model/__init__.py
@@ -6,5 +6,17 @@ from astrai.model.module import (
RMSNorm,
)
from astrai.model.transformer import Transformer
+from astrai.model.automodel import AutoModel
-__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
+
+__all__ = [
+ # Modules
+ "Linear",
+ "RMSNorm",
+ "MLP",
+ "GQA",
+ "DecoderBlock",
+ # Models
+ "Transformer",
+ "AutoModel",
+]
diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py
new file mode 100644
index 0000000..b2cbad1
--- /dev/null
+++ b/astrai/model/automodel.py
@@ -0,0 +1,134 @@
+"""
+AutoModel base class for model loading and saving.
+"""
+
+import torch.nn as nn
+import safetensors.torch as st
+
+from pathlib import Path
+from contextlib import contextmanager
+from typing import Self, Union, Dict, Type
+
+from astrai.config import ModelConfig
+
+
+@contextmanager
+def _disable_random_init(enable: bool = True):
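+    """Temporarily replace the common nn.init functions with no-ops."""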
+ init_functions = [
+ "xavier_normal_",
+ "xavier_uniform_",
+ "kaiming_normal_",
+ "kaiming_uniform_",
+ "zeros_",
+ "ones_",
+ "constant_",
+ "normal_",
+ "uniform_",
+ ]
+ original_funcs = {}
+ for name in init_functions:
+ if enable and hasattr(nn.init, name):
+ original_funcs[name] = getattr(nn.init, name)
+ setattr(nn.init, name, lambda *args, **kwargs: None)
+ try:
+ yield
+ finally:
+ if enable:
+ for name, orig_func in original_funcs.items():
+ setattr(nn.init, name, orig_func)
+
+
+class AutoModel(nn.Module):
+ """
+ Autoregressive language model base class.
+    Provides model-type registration and loading/saving capabilities.
+ """
+
+ # Model registry - stored as class attribute
+ _registry: Dict[str, Type["AutoModel"]] = {}
+
+ def __init__(self, config: ModelConfig):
+ super().__init__()
+ self.config = config
+
+ @classmethod
+ def register(cls, model_type: str):
+ """
+ Class method decorator to register model type.
+
+ Usage:
+ @AutoModel.register('transformer')
+ class Transformer(AutoModel):
+ ...
+ """
+
+ def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
+ cls._registry[model_type.lower()] = sub_cls
+ return sub_cls
+
+ return decorator
+
+ @classmethod
+ def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
+ """Get model class by model_type string."""
+ model_type = model_type.lower()
+ if model_type not in cls._registry:
+ available = list(cls._registry.keys())
+ raise ValueError(
+ f"Unknown model_type: {model_type}. Available: {available}"
+ )
+ return cls._registry[model_type]
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ path: Union[str, Path],
+ disable_random_init: bool = True,
+ ) -> nn.Module:
+
+ model_path = Path(path)
+
+ # Load config
+ config = ModelConfig()
+ config_path = model_path / "config.json"
+ if config_path.exists():
+ config.load(str(config_path))
+ else:
+ raise FileNotFoundError(f"Config file not found: {config_path}")
+
+ # If called from base class, use model_type to determine actual model class
+ if cls is AutoModel:
+ model_type = config.model_type or "transformer"
+ actual_cls = cls.get_model_class(model_type)
+ else:
+ raise ValueError(
+ f"Cannot call from_pretrained() on subclass {cls.__name__}"
+ )
+
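+        # Skip random weight initialization when the weights are about to be
+        # overwritten by the checkpoint loaded below.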
+ with _disable_random_init(enable=disable_random_init):
+ model = actual_cls(config)
+
+ # Load weights
+ weights_path = model_path / "model.safetensors"
+ if weights_path.exists():
+ state_dict = st.load_file(str(weights_path))
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, Path],
+ ) -> None:
+ save_path = Path(save_directory)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ # Save config
+ self.config.save(str(save_path / "config.json"))
+
+ # Save weights
+ st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
+
+ def to(self, *args, **kwargs) -> Self:
+ """Move model to device/dtype."""
+ return super().to(*args, **kwargs)
diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py
index 281cd4a..2198a4b 100644
--- a/astrai/model/transformer.py
+++ b/astrai/model/transformer.py
@@ -5,6 +5,7 @@ import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
+from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
Embedding,
@@ -66,9 +67,14 @@ def process_attention_mask(
return attention_mask
-class Transformer(nn.Module):
+@AutoModel.register("transformer")
+class Transformer(AutoModel):
+ """
+ Transformer language model.
+ """
+
def __init__(self, config: ModelConfig):
- super().__init__()
+ super().__init__(config)
self.config = config
self.rotary_embeding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
@@ -97,16 +103,27 @@ class Transformer(nn.Module):
if self.config.tie_weight:
self.lm_head.weight = self.embed_tokens.weight
- self._init_parameters()
+ self._init_weights()
+
+ def _init_weights(self):
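+        # Initialize every weight matrix (dim > 1) from N(0, 0.006);
+        # biases and norm parameters keep their module defaults.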
+ for param in self.parameters():
+ if param.dim() > 1:
+ nn.init.normal_(param, mean=0.0, std=0.006)
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
+ # Make a copy to avoid modifying the original state_dict
+ state_dict = dict(state_dict)
+
if self.config.tie_weight:
# same tensor
- state_dict[lm_head_key] = state_dict[embed_key]
+ if embed_key in state_dict:
+ state_dict[lm_head_key] = state_dict[embed_key]
else:
+ # If lm_head.weight exists in checkpoint, use it directly
+ # If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
@@ -125,11 +142,6 @@ class Transformer(nn.Module):
return state_dict
- def _init_parameters(self):
- for param in self.parameters():
- if param.dim() > 1:
- nn.init.normal_(param, mean=0.0, std=0.006)
-
def forward(
self,
input_ids: Tensor,
diff --git a/astrai/tokenize/__init__.py b/astrai/tokenize/__init__.py
index 5c7fe0c..e98db15 100644
--- a/astrai/tokenize/__init__.py
+++ b/astrai/tokenize/__init__.py
@@ -1,22 +1,23 @@
from astrai.tokenize.tokenizer import (
- BaseTokenizer,
+ TextTokenizer,
BpeTokenizer,
- BaseTrainer,
- BpeTrainer,
)
+from astrai.tokenize.trainer import BpeTrainer
from astrai.tokenize.chat_template import (
+ ChatTemplate,
HistoryType,
MessageType,
- build_prompt,
)
+# Alias for compatibility
+AutoTokenizer = TextTokenizer
+
__all__ = [
- "BaseTokenizer",
+ "TextTokenizer",
+ "AutoTokenizer",
"BpeTokenizer",
- "BaseTrainer",
"BpeTrainer",
+ "ChatTemplate",
"HistoryType",
"MessageType",
- "CHAT_TEMPLATES",
- "build_prompt",
]
diff --git a/astrai/tokenize/chat_template.py b/astrai/tokenize/chat_template.py
index aa04f2f..6d4b2ee 100644
--- a/astrai/tokenize/chat_template.py
+++ b/astrai/tokenize/chat_template.py
@@ -1,7 +1,6 @@
from typing import Dict, List, Optional, Tuple, Any
from jinja2 import Template
from dataclasses import dataclass
-from astrai.factory import Registry
HistoryType = List[Tuple[str, str]]
MessageType = Dict[str, str]
@@ -75,213 +74,3 @@ class ChatTemplate:
jinja_template = Template(self.template_str)
return jinja_template.render(**variables)
-
-
-# Global registry instance
-_default_registry = Registry()
-
-# Default template name
-_default_template_name = "chatml"
-
-
-# Convenience functions
-def register_chat_template(
- name: str,
- template_str: str,
- description: str = "",
- default_variables: Optional[Dict[str, Any]] = None,
- special_tokens: Optional[Dict[str, str]] = None,
-) -> ChatTemplate:
- """Register a chat template in the global registry."""
- template = ChatTemplate(
- name=name,
- template_str=template_str,
- description=description,
- default_variables=default_variables,
- special_tokens=special_tokens,
- )
- _default_registry.register(name, template, category=None, priority=0)
- return template
-
-
-def set_default_chat_template(name: str) -> None:
- """Set the default chat template name globally."""
- global _default_template_name
- if not _default_registry.contains(name):
- raise KeyError(
- f"Chat template '{name}' not found. Available: {list(_default_registry.list_names())}"
- )
- _default_template_name = name
-
-
-def get_default_chat_template_name() -> str:
- """Get the current default chat template name."""
- return _default_template_name
-
-
-def get_chat_template(name: str) -> ChatTemplate:
- """Get a chat template from the global registry."""
- return _default_registry.get(name)
-
-
-def list_chat_templates() -> List[str]:
- """List all registered chat template names."""
- return _default_registry.list_names()
-
-
-def chat_template_exists(name: str) -> bool:
- """Check if a chat template exists."""
- return _default_registry.contains(name)
-
-
-def build_prompt(
- query: str,
- system_prompt: Optional[str] = None,
- history: Optional[HistoryType] = None,
- template: Optional[str] = None,
- template_name: Optional[str] = None,
- **extra_variables: Any,
-) -> str:
- """Build prompt using a registered chat template or a custom template string.
-
- This function maintains backward compatibility with the previous API.
-
- Args:
- query: The current user query.
- system_prompt: Optional system prompt.
- history: Optional list of (user_msg, assistant_msg) pairs.
- template: If provided, uses this exact Jinja2 template string (overrides template_name).
- template_name: Name of a registered template to use (ignored if `template` is given).
- If None, uses the globally set default template (see `set_default_chat_template`).
- **extra_variables: Additional variables to pass to the template.
-
- Returns:
- Rendered prompt string.
-
- Raises:
- KeyError: If `template_name` is not registered.
- """
- # Convert history to message format
- messages: List[MessageType] = []
- if history:
- for user_msg, assistant_msg in history:
- messages.append({"role": "user", "content": user_msg})
- messages.append({"role": "assistant", "content": assistant_msg})
- messages.append({"role": "user", "content": query})
-
- if template is not None:
- # Use the provided template string directly
- jinja_template = Template(template)
- variables = {"messages": messages, **extra_variables}
- if system_prompt is not None:
- variables["system_prompt"] = system_prompt
- return jinja_template.render(**variables)
- else:
- # Determine which template name to use
- if template_name is None:
- template_name = _default_template_name
- # Use a registered template
- chat_template = get_chat_template(template_name)
- return chat_template.render(
- messages=messages,
- system_prompt=system_prompt,
- **extra_variables,
- )
-
-
-# Predefined templates
-# ChatML template (original)
-register_chat_template(
- name="chatml",
- template_str=(
- "{%- if system_prompt -%}\n"
- "{{ bos_token }}system\n"
- "{{ system_prompt }}{{ eos_token }}\n"
- "{%- endif -%}\n"
- "{%- for message in messages -%}\n"
- "{{ bos_token }}{{ message['role'] }}\n"
- "{{ message['content'] }}{{ eos_token }}\n"
- "{%- endfor -%}\n"
- "{{ bos_token }}assistant\n"
- ),
- description="ChatML format with configurable special tokens.",
- special_tokens={"bos_token": "<|im▁start|>", "eos_token": "<|im▁end|>"},
-)
-
-# Simplified template without special tokens (plain text)
-register_chat_template(
- name="plain",
- template_str=(
- "{%- if system_prompt -%}\n"
- "System: {{ system_prompt }}\n"
- "{%- endif -%}\n"
- "{%- for message in messages -%}\n"
- "{{ message['role']|capitalize }}: {{ message['content'] }}\n"
- "{%- endfor -%}\n"
- "Assistant:"
- ),
- description="Plain text format with role labels.",
-)
-
-# Alpaca-style template
-register_chat_template(
- name="alpaca",
- template_str=(
- "{%- if system_prompt -%}\n"
- "### Instruction:\n"
- "{{ system_prompt }}\n"
- "{%- endif -%}\n"
- "### Input:\n"
- "{{ messages[-1]['content'] }}\n"
- "### Response:"
- ),
- description="Alpaca instruction‑response format (single‑turn).",
- default_variables={},
-)
-
-# OpenAI chat format (approximation)
-register_chat_template(
- name="openai",
- template_str=(
- "{%- if system_prompt -%}\n"
- "{{ bos_token }}system\n"
- "{{ system_prompt }}{{ eos_token }}\n"
- "{%- endif -%}\n"
- "{%- for message in messages -%}\n"
- "{{ bos_token }}{{ message['role'] }}\n"
- "{{ message['content'] }}{{ eos_token }}\n"
- "{%- endfor -%}\n"
- "{{ bos_token }}assistant\n"
- ),
- description="OpenAI‑compatible chat format with configurable special tokens.",
- special_tokens={"bos_token": "<|im▁start|>", "eos_token": "<|im▁end|>"},
-)
-
-# Llama‑2 style with [INST] tags
-register_chat_template(
- name="llama2",
- template_str=(
- "{%- if system_prompt -%}\n"
- "<>\n"
- "{{ system_prompt }}\n"
- "<>\n"
- "{%- endif -%}\n"
- "[INST] {{ messages[-1]['content'] }} [/INST]"
- ),
- description="Llama‑2 style with [INST] tags (single‑turn).",
- default_variables={},
-)
-
-
-__all__ = [
- "ChatTemplate",
- "register_chat_template",
- "get_chat_template",
- "list_chat_templates",
- "chat_template_exists",
- "build_prompt",
- "set_default_chat_template",
- "get_default_chat_template_name",
- "HistoryType",
- "MessageType",
-]
diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py
index 96cab47..c6e2fc7 100644
--- a/astrai/tokenize/tokenizer.py
+++ b/astrai/tokenize/tokenizer.py
@@ -1,97 +1,254 @@
-from abc import ABC, abstractmethod
-from typing import List, Union
+"""
+Tokenizer module: TextTokenizer base class with auto-loading support and a BPE implementation.
+"""
+
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Union
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
+from astrai.tokenize.chat_template import ChatTemplate
-class BaseTokenizer(ABC):
- @abstractmethod
- def _init_tokenizer(self):
- pass
+class TextTokenizer:
+ """Base tokenizer class with automatic loading support"""
- @abstractmethod
- def save(self, path):
- pass
+ TOKENIZER_CLASSES = {} # Registry for auto-loading
- @abstractmethod
- def load(self, path):
- pass
+ def __init__(
+ self,
+ path: Optional[Union[str, Path]] = None,
+ special_token_map: Optional[Dict[str, str]] = None,
+ chat_template: Optional[str] = None,
+ ):
+        # Keep a tokenizer a subclass may already have built before calling super().__init__
+        self._tokenizer: Optional[Tokenizer] = self.__dict__.get("_tokenizer")
+        self._chat_template: Optional[ChatTemplate] = None
+        self._special_token_map: Dict[str, str] = special_token_map or {}
+
+ if chat_template:
+ self.set_chat_template(chat_template)
+
+ if path:
+ self.load(path)
+
+ def load(self, path: Union[str, Path]):
+ """Load tokenizer from directory."""
+ path = Path(path)
+ tokenizer_file = path / "tokenizer.json"
+ config_file = path / "tokenizer_config.json"
+ self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
+
+ if config_file.exists():
+ with open(config_file, "r", encoding="utf-8") as f:
+ config = json.load(f)
+
+ if "special_tokens" in config:
+ self._special_token_map.update(config["special_tokens"])
+
+ # Load chat template from config
+ if "chat_template" in config:
+ self.set_chat_template(config["chat_template"])
+
+ @classmethod
+ def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
+ """Load tokenizer from pretrained directory."""
+        # Pass path by keyword so subclasses with a different argument order also work
+        return cls(path=path, **kwargs)
+
+    def save_pretrained(self, save_path: Union[str, Path]):
+        """
+        Save the tokenizer to a pretrained directory.
+
+        Args:
+            save_path: Directory to save the tokenizer into
+        """
+        save_path = Path(save_path)
+        save_path.mkdir(parents=True, exist_ok=True)
+        # The underlying tokenizers.Tokenizer is written as tokenizer.json
+        self._tokenizer.save(str(save_path / "tokenizer.json"))
+
+ @classmethod
+ def register_tokenizer(cls, name: str, tokenizer_class: type):
+ """
+ Register a new tokenizer class.
+
+ Args:
+ name: Name to register the tokenizer class under
+ tokenizer_class: The tokenizer class to register
+ """
+ cls.TOKENIZER_CLASSES[name] = tokenizer_class
- @abstractmethod
def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
- add_special_tokens: bool = False,
+ is_pretokenized: bool = False,
+ add_special_tokens: bool = True,
) -> List:
- pass
+ """Encode text to tokens or token IDs."""
+ if self._tokenizer is None:
+ raise RuntimeError(
+ "Tokenizer not initialized. Load or create a tokenizer first."
+ )
+
+ if isinstance(tokens, str):
+ encoded = self._tokenizer.encode(
+ tokens,
+ is_pretokenized=is_pretokenized,
+ add_special_tokens=add_special_tokens,
+ )
+ return encoded.ids if out_ids else encoded.tokens
+ else:
+ encoded_list = self._tokenizer.encode_batch(
+ tokens,
+ is_pretokenized=is_pretokenized,
+ add_special_tokens=add_special_tokens,
+ )
+ return [
+ encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
+ ]
- @abstractmethod
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
- pass
+ """Decode token IDs to text."""
+ if self._tokenizer is None:
+ raise RuntimeError(
+ "Tokenizer not initialized. Load or create a tokenizer first."
+ )
+
+ return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
- @abstractmethod
def __len__(self) -> int:
- pass
+ if self._tokenizer is None:
+ return 0
+ return self._tokenizer.get_vocab_size()
+
+ def __getattr__(self, key: str):
+ """
+        Dynamically intercept special-token attribute access. Keys come from the
+        special_token_map (e.g. "bos", "eos", "pad"). Supports three forms:
+        - tokenizer.bos → the token string
+        - tokenizer.bos_id → the corresponding integer ID
+        - tokenizer.stop_ids → the integer IDs of all special tokens
+        """
+        # Handle stop_ids: map every special token string to its ID
+        if key == "stop_ids":
+            return [
+                self._tokenizer.token_to_id(tok)
+                for tok in self._special_token_map.values()
+            ]
+
+ # Handle _id suffix (e.g., bos_token_id -> bos_token)
+ if key.endswith("_id"):
+ base_attr = key[:-3] # Remove "_id"
+ token_str = self._special_token_map.get(base_attr)
+ if token_str is None:
+ return None
+ if self._tokenizer is None:
+ raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
+ return self._tokenizer.token_to_id(token_str)
+
+ # Handle regular string attributes
+ if key in self._special_token_map:
+ return self._special_token_map.get(key)
+
+ # Other attributes trigger default AttributeError
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
@property
- @abstractmethod
- def stop_ids(self) -> List[int]:
- pass
+ def vocab_size(self) -> int:
+ return len(self)
@property
- @abstractmethod
- def bos_id(self) -> int:
- pass
+ def pad_id(self) -> Optional[int]:
+ """Return the pad token ID if available."""
+ pad_token = self._special_token_map.get("pad")
+ if pad_token is None or self._tokenizer is None:
+ return None
+ return self._tokenizer.token_to_id(pad_token)
- @property
- @abstractmethod
- def eos_id(self) -> int:
- pass
+ def set_chat_template(self, template: Union[str, ChatTemplate]):
+ """
+ Set the chat template for the tokenizer.
- @property
- @abstractmethod
- def pad_id(self) -> int:
- pass
+ Args:
+            template: Either a Jinja2 template string (parsed with
+                      ChatTemplate.from_string) or a ChatTemplate instance.
+
+        Raises:
+            ValueError: If the template is neither a string nor a ChatTemplate.
+ """
+ if isinstance(template, str):
+ self._chat_template = ChatTemplate.from_string(template)
+ elif isinstance(template, ChatTemplate):
+ self._chat_template = template
+ else:
+ raise ValueError("Invalid template type, must be str or ChatTemplate.")
+
+ def apply_chat_template(
+ self,
+ messages: List[Dict[str, str]],
+ system_prompt: Optional[str] = None,
+ tokenize: bool = True,
+ **kwargs,
+ ) -> Union[str, List[int]]:
+ """
+ Apply the chat template to messages and optionally tokenize the result.
+
+ Args:
+ messages: List of message dicts with 'role' and 'content'.
+ system_prompt: Optional system prompt string.
+ tokenize: Whether to return token IDs (True) or raw string (False).
+ **kwargs: Additional variables to pass to the template.
+
+ Returns:
+ Either the rendered string or list of token IDs.
+
+ Raises:
+ RuntimeError: If chat template is not set.
+ """
+ if self._chat_template is None:
+ raise RuntimeError(
+ "Chat template not set. Use set_chat_template() to set a template first."
+ )
+
+ # Render the template
+ rendered = self._chat_template.render(
+ messages=messages,
+ system_prompt=system_prompt,
+ **kwargs,
+ )
+
+ if tokenize:
+ return self.encode(rendered)
+
+ return rendered
-class BaseTrainer(ABC):
- def __init__(self, tokenizer: BaseTokenizer):
- self.tokenizer = tokenizer
+class BpeTokenizer(TextTokenizer):
+ """BPE tokenizer implementation."""
- @abstractmethod
- def train(self, files, vocab_size, min_freq, **kwargs):
- pass
-
- @abstractmethod
- def train_from_iterator(self, iterator, vocab_size, min_freq, **kwargs):
- pass
-
-
-class BpeTokenizer(BaseTokenizer):
def __init__(
self,
- control_tokens: List[str] = None,
- special_tokens: List[str] = None,
- path=None,
+ special_token_map: Dict[str, str] = None,
+ path: Optional[str] = None,
+ chat_template: Optional[str] = None,
):
- self._control_tokens = control_tokens or [
- "<|begin▁of▁sentence|>",
- "<|end▁of▁sentence|>",
- "<|▁pad▁|>",
- ]
- self._special_tokens = special_tokens or [
- "<|im▁start|>",
- "<|im▁end|>",
- ]
+ special_token_map = special_token_map or {
+ "bos": "<|begin▁of▁sentence|>",
+ "eos": "<|end▁of▁sentence|>",
+ "pad": "<|▁pad▁|>",
+ "im_start": "<|im▁start|>",
+ "im_end": "<|im▁end|>",
+ }
self._tokenizer = None
self._init_tokenizer()
- if path is not None:
- self.load(path)
+ super().__init__(
+ path, special_token_map=special_token_map, chat_template=chat_template
+ )
def _init_tokenizer(self):
+ """Initialize a new BPE tokenizer with default settings."""
model = BPE()
self._tokenizer = Tokenizer(model)
self._tokenizer.normalizer = normalizers.Sequence(
@@ -105,108 +262,3 @@ class BpeTokenizer(BaseTokenizer):
)
self._tokenizer.decoder = decoders.ByteLevel()
self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
-
- def save(self, path):
- self._tokenizer.save(path)
-
- def load(self, path):
- self._tokenizer = Tokenizer.from_file(path)
-
- def encode(
- self,
- tokens: Union[str, List[str]],
- out_ids: bool = True,
- add_special_tokens: bool = False,
- ) -> List:
- if isinstance(tokens, str):
- encoded = self._tokenizer.encode(
- tokens, add_special_tokens=add_special_tokens
- )
- return encoded.ids if out_ids else encoded.tokens
- else:
- encoded_list = self._tokenizer.encode_batch(
- tokens, add_special_tokens=add_special_tokens
- )
- return [
- encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
- ]
-
- def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
- return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
-
- def __len__(self) -> int:
- return self._tokenizer.get_vocab_size()
-
- @property
- def stop_ids(self) -> List[int]:
- stop_token = self._control_tokens + self._special_tokens
- return [self._tokenizer.token_to_id(tok) for tok in stop_token]
-
- @property
- def bos_id(self) -> int:
- return self._tokenizer.token_to_id(self._control_tokens[0])
-
- @property
- def eos_id(self) -> int:
- return self._tokenizer.token_to_id(self._control_tokens[1])
-
- @property
- def pad_id(self) -> int:
- return self._tokenizer.token_to_id(self._control_tokens[2])
-
-
-class BpeTrainer(BaseTrainer):
- def __init__(self, tokenizer: BaseTokenizer):
- super().__init__(tokenizer)
-
- def _prepare_trainer(
- self,
- vocab_size: int,
- min_freq: int,
- reserved_token_size: int,
- max_token_length=18,
- ):
- assert reserved_token_size > len(self.tokenizer._special_tokens)
- reserved_tokens = [
- f"<|reserve{i:02d}|>"
- for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
- ]
- detail_vocab_size = vocab_size - (
- len(reserved_tokens) + len(self.tokenizer._special_tokens)
- )
- alphabet = pre_tokenizers.ByteLevel.alphabet()
- min_size = len(alphabet) + len(self.tokenizer._control_tokens)
- assert detail_vocab_size > min_size
-
- trainer = BpeTrainerImpl(
- vocab_size=detail_vocab_size,
- min_frequency=min_freq,
- limit_alphabet=detail_vocab_size // 6,
- max_token_length=max_token_length,
- special_tokens=self.tokenizer._control_tokens,
- initial_alphabet=alphabet,
- show_progress=True,
- )
- return trainer, reserved_tokens
-
- def train(self, files, vocab_size, min_freq, reserved_token_size=100, **kwargs):
- trainer, reserved_tokens = self._prepare_trainer(
- vocab_size, min_freq, reserved_token_size, **kwargs
- )
- self.tokenizer._tokenizer.train(files=files, trainer=trainer)
- self.tokenizer._tokenizer.add_special_tokens(
- self.tokenizer._special_tokens + reserved_tokens
- )
-
- def train_from_iterator(
- self, iterator, vocab_size, min_freq, reserved_token_size=100, **kwargs
- ):
- trainer, reserved_tokens = self._prepare_trainer(
- vocab_size, min_freq, reserved_token_size, **kwargs
- )
- self.tokenizer._tokenizer.train_from_iterator(
- iterator=iterator, trainer=trainer
- )
- self.tokenizer._tokenizer.add_special_tokens(
- self.tokenizer._special_tokens + reserved_tokens
- )
diff --git a/astrai/tokenize/trainer.py b/astrai/tokenize/trainer.py
new file mode 100644
index 0000000..20e5536
--- /dev/null
+++ b/astrai/tokenize/trainer.py
@@ -0,0 +1,108 @@
+"""
+BPE Tokenizer Trainer module.
+
+Provides training functionality for BPE tokenizers.
+"""
+
+from typing import List, Union
+
+from tokenizers import pre_tokenizers
+from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
+
+
+class BpeTrainer:
+ """BPE tokenizer trainer."""
+
+ def __init__(self, tokenizer):
+ """Initialize trainer with a tokenizer instance.
+
+ Args:
+ tokenizer: A BpeTokenizer instance
+ """
+ self.tokenizer = tokenizer
+
+ def _prepare_trainer(
+ self,
+ vocab_size: int,
+ min_freq: int,
+ reserved_token_size: int,
+ max_token_length: int = 18,
+ ):
+ """Prepare the BPE trainer with proper configuration."""
+ assert reserved_token_size > len(self.tokenizer._special_tokens)
+ reserved_tokens = [
+ f"<|reserve{i:02d}|>"
+ for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
+ ]
+ detail_vocab_size = vocab_size - (
+ len(reserved_tokens) + len(self.tokenizer._special_tokens)
+ )
+ alphabet = pre_tokenizers.ByteLevel.alphabet()
+ min_size = len(alphabet) + len(self.tokenizer._control_tokens)
+ assert detail_vocab_size > min_size
+
+ trainer = BpeTrainerImpl(
+ vocab_size=detail_vocab_size,
+ min_frequency=min_freq,
+ limit_alphabet=detail_vocab_size // 6,
+ max_token_length=max_token_length,
+ special_tokens=self.tokenizer._control_tokens,
+ initial_alphabet=alphabet,
+ show_progress=True,
+ )
+ return trainer, reserved_tokens
+
+ def train(
+ self,
+ files: Union[str, List[str]],
+ vocab_size: int,
+ min_freq: int,
+ reserved_token_size: int = 100,
+ **kwargs,
+ ):
+ """Train tokenizer from files.
+
+ Args:
+ files: Path or list of paths to training files
+ vocab_size: Target vocabulary size
+ min_freq: Minimum frequency for tokens
+ reserved_token_size: Number of reserved tokens
+ **kwargs: Additional arguments
+ """
+ trainer, reserved_tokens = self._prepare_trainer(
+ vocab_size, min_freq, reserved_token_size, **kwargs
+ )
+ self.tokenizer._tokenizer.train(files=files, trainer=trainer)
+        # Special tokens were already registered via the trainer; only the
+        # reserved placeholder tokens still need to be added.
+        self.tokenizer._tokenizer.add_special_tokens(reserved_tokens)
+
+ def train_from_iterator(
+ self,
+ iterator,
+ vocab_size: int,
+ min_freq: int,
+ reserved_token_size: int = 100,
+ **kwargs,
+ ):
+ """Train tokenizer from iterator.
+
+ Args:
+ iterator: Iterator yielding training strings
+ vocab_size: Target vocabulary size
+ min_freq: Minimum frequency for tokens
+ reserved_token_size: Number of reserved tokens
+ **kwargs: Additional arguments
+ """
+ trainer, reserved_tokens = self._prepare_trainer(
+ vocab_size, min_freq, reserved_token_size, **kwargs
+ )
+ self.tokenizer._tokenizer.train_from_iterator(
+ iterator=iterator, trainer=trainer
+ )
+        self.tokenizer._tokenizer.add_special_tokens(reserved_tokens)
+
+
+__all__ = ["BpeTrainer"]
diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py
index a72a94c..698ecec 100644
--- a/scripts/demo/generate_ar.py
+++ b/scripts/demo/generate_ar.py
@@ -2,24 +2,32 @@ from pathlib import Path
import torch
-from astrai.config.param_config import ModelParameter
from astrai.inference import InferenceEngine
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
- param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
- param.to(device="cuda", dtype=torch.bfloat16)
+ # Load model from pretrained
+ model = AutoModel.from_pretrained(PARAMETER_ROOT)
+ model.to(device="cuda", dtype=torch.bfloat16)
+
+ # Load tokenizer from pretrained
+    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
query = input(">> ")
- engine = InferenceEngine(param)
+ engine = InferenceEngine(
+ model=model,
+ tokenizer=tokenizer,
+ )
response = engine.generate(
prompt=query,
stream=False,
- max_tokens=param.config.max_len,
+ max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,
diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py
index 754ab4c..0662d3a 100644
--- a/scripts/demo/generate_batch.py
+++ b/scripts/demo/generate_batch.py
@@ -2,7 +2,7 @@ from pathlib import Path
import torch
-from astrai.config.param_config import ModelParameter
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine
PROJECT_ROOT = Path(__file__).resolve().parents[2]
@@ -10,8 +10,10 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def batch_generate():
- param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
- param.to(device="cuda", dtype=torch.bfloat16)
+    # Load model and tokenizer, then move the model to the target device/dtype
+    model = AutoModel.from_pretrained(PARAMETER_ROOT)
+    model.to(device="cuda", dtype=torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
inputs = [
"你好",
@@ -21,11 +23,14 @@ def batch_generate():
"请问什么是显卡",
]
- engine = InferenceEngine(param)
+    engine = InferenceEngine(
+        model=model,
+        tokenizer=tokenizer,
+    )
responses = engine.generate(
prompt=inputs,
stream=False,
- max_tokens=param.config.max_len,
+ max_tokens=model.config.max_len,
temperature=0.8,
top_p=0.95,
top_k=50,
diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py
index f06e4fd..1a9af74 100644
--- a/scripts/demo/stream_chat.py
+++ b/scripts/demo/stream_chat.py
@@ -1,32 +1,39 @@
from pathlib import Path
import torch
-
-from astrai.config.param_config import ModelParameter
from astrai.inference import InferenceEngine
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
+
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def chat():
- param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
- param.to(device="cuda", dtype=torch.bfloat16)
+ model = AutoModel.from_pretrained(PARAMETER_ROOT)
+ tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
+ model.to(device="cuda", dtype=torch.bfloat16)
- history = []
- engine = InferenceEngine(param)
+ messages = []
+ engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True:
query = input(">> ")
if query == "!exit":
break
+ # Add user message
+ messages.append({"role": "user", "content": query})
+
+ # Generate response
full_response = ""
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False)
for token in engine.generate(
- prompt=query,
+ prompt=prompt,
stream=True,
- max_tokens=param.config.max_len,
+ max_tokens=model.config.max_len,
temperature=0.8,
top_p=0.95,
top_k=50,
@@ -35,7 +42,8 @@ def chat():
full_response += token
print()
- history.append((query, full_response.strip()))
+ # Add assistant response to messages
+ messages.append({"role": "assistant", "content": full_response.strip()})
if __name__ == "__main__":
diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py
index 7011ec5..641c4ad 100644
--- a/scripts/tools/generate.py
+++ b/scripts/tools/generate.py
@@ -3,7 +3,8 @@ import json
import torch
-from astrai.config.param_config import ModelParameter
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine
@@ -17,9 +18,9 @@ def processor(
question_key: str,
response_key: str,
):
- param = ModelParameter.load(model_dir, disable_init=True)
- param.to(device="cuda", dtype=torch.bfloat16)
- engine = InferenceEngine(param)
+    # Load model and tokenizer, then move the model to the target device/dtype
+    model = AutoModel.from_pretrained(model_dir)
+    model.to(device="cuda", dtype=torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    engine = InferenceEngine(model=model, tokenizer=tokenizer)
with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]
@@ -29,7 +30,7 @@ def processor(
responses = engine.generate(
prompt=queries,
stream=False,
- max_tokens=param.config.max_len,
+ max_tokens=model.config.max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
diff --git a/scripts/tools/perplexity.py b/scripts/tools/perplexity.py
index a67a231..ebc4060 100644
--- a/scripts/tools/perplexity.py
+++ b/scripts/tools/perplexity.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
import tqdm
from torch import Tensor
-from astrai.config.param_config import ModelParameter
+from astrai.model import AutoModel
def compute_perplexity(
@@ -20,7 +20,7 @@ def compute_perplexity(
where PPL = exp(-(1/N) * sum(log P(w_i | w_