diff --git a/.gitattributes b/.gitattributes
index 1de0c60..60472b8 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -3,6 +3,7 @@
 # Files that MUST use LF (Unix/Linux execution)
 *.sh text eol=lf
+*.py text eol=lf
 Dockerfile text eol=lf
 .dockerignore text eol=lf
diff --git a/assets/docs/params.md b/assets/docs/params.md
index 9d4bf4b..6c6f550 100644
--- a/assets/docs/params.md
+++ b/assets/docs/params.md
@@ -90,8 +90,12 @@ from astrai.inference import InferenceEngine, GenerationRequest
-param = ModelParameter.load("your_model_dir")
-param.to(device="cuda", dtype=torch.bfloat16)
+model = AutoModel.from_pretrained("your_model_dir")
+tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
+model.to(device="cuda", dtype=torch.bfloat16)
 
-# Create engine
-engine = InferenceEngine(param)
+# Create engine with separate model and tokenizer
+engine = InferenceEngine(
+    model=model,
+    tokenizer=tokenizer,
+)
 
 # Build request
 request = GenerationRequest(
diff --git a/astrai/__init__.py b/astrai/__init__.py
index 423ab84..b98a372 100644
--- a/astrai/__init__.py
+++ b/astrai/__init__.py
@@ -12,7 +12,7 @@ from astrai.inference import (
     GenerationRequest,
     InferenceEngine,
 )
-from astrai.model.transformer import Transformer
+from astrai.model import AutoModel, Transformer
 from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
 
 __all__ = [
@@ -27,4 +27,5 @@ __all__ = [
     "StrategyFactory",
     "SchedulerFactory",
     "BaseFactory",
+    "AutoModel",
 ]
diff --git a/astrai/config/__init__.py b/astrai/config/__init__.py
index 7b90769..bd47b59 100644
--- a/astrai/config/__init__.py
+++ b/astrai/config/__init__.py
@@ -1,11 +1,7 @@
 from astrai.config.model_config import ModelConfig
-from astrai.config.param_config import BaseModelIO, ModelParameter
 from astrai.config.train_config import TrainConfig
 
 __all__ = [
-    # Base I/O
-    "BaseModelIO",
-    "ModelParameter",
     # Model configuration
     "ModelConfig",
     "TrainConfig",
diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py
index 9789cd1..cf1ebb8 100644
--- a/astrai/config/model_config.py
+++ b/astrai/config/model_config.py
@@ -6,6 +6,7 @@ from typing import Optional, Self
 @dataclass
 class ModelConfig:
     # basic config
+    model_type: Optional[str] = None
     vocab_size: Optional[int] = None
     dim: Optional[int] = None
diff --git a/astrai/config/param_config.py b/astrai/config/param_config.py
deleted file mode 100644
index 9f15acf..0000000
--- a/astrai/config/param_config.py
+++ /dev/null
@@ -1,114 +0,0 @@
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Self, Union
-
-import safetensors.torch as st
-import torch.nn as nn
-
-from astrai.config.model_config import ModelConfig
-from astrai.tokenize import BpeTokenizer
-from astrai.model.transformer import Transformer
-
-
-@contextmanager
-def disable_random_init(enable: bool = True):
-    init_functions = [
-        "xavier_normal_",
-        "xavier_uniform_",
-        "kaiming_normal_",
-        "kaiming_uniform_",
-        "zeros_",
-        "ones_",
-        "constant_",
-        "normal_",
-        "uniform_",
-    ]
-    original_funcs = {}
-    for name in init_functions:
-        if enable and hasattr(nn.init, name):
-            original_funcs[name] = getattr(nn.init, name)
-            setattr(nn.init, name, lambda *args, **kwargs: None)
-    try:
-        yield
-    finally:
-        if enable:
-            for name, orig_func in original_funcs.items():
-                setattr(nn.init, name, orig_func)
-
-
-@dataclass
-class BaseModelIO:
-    """Base class for model I/O operations."""
-
-    model: nn.Module = field(
-        default_factory=nn.Identity, metadata={"help": "Transformer model."}
-    )
-    tokenizer: BpeTokenizer = field(
-        default_factory=BpeTokenizer, metadata={"help": "Tokenizer for the
model."} - ) - config: ModelConfig = field( - default_factory=ModelConfig, - metadata={"help": "Transformer model configuration."}, - ) - - def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]: - """Get standardized file paths for model components.""" - dir_path = Path(directory) - return { - "model": dir_path / "model.safetensors", - "config": dir_path / "config.json", - "tokenizer": dir_path / "tokenizer.json", - } - - def save_components(self, save_dir: Union[str, Path]): - """Save core model components.""" - paths = self._get_file_paths(save_dir) - paths["model"].parent.mkdir(parents=True, exist_ok=True) - - if self.model is not None: - st.save_file(self.model.state_dict(), str(paths["model"])) - - self.config.save(str(paths["config"])) - self.tokenizer.save(str(paths["tokenizer"])) - - def load_components( - self, load_dir: Union[str, Path], disable_init: bool = False - ) -> Self: - """Load core model components.""" - paths = self._get_file_paths(load_dir) - - self.config.load(str(paths["config"])) - self.tokenizer.load(str(paths["tokenizer"])) - - if isinstance(self.model, nn.Identity): - with disable_random_init(enable=disable_init): - self.model = Transformer(self.config) - - if paths["model"].exists(): - state_dict = st.load_file(str(paths["model"])) - self.model.load_state_dict(state_dict) - - return self - - def to(self, *args, **kwargs) -> "BaseModelIO": - """Move model to device.""" - if self.model is not None: - self.model.to(*args, **kwargs) - return self - - -@dataclass -class ModelParameter(BaseModelIO): - """Container for model parameters with serialization capabilities.""" - - @classmethod - def save(cls, instance: "ModelParameter", save_dir: Union[str, Path]): - instance.save_components(save_dir) - - @classmethod - def load( - cls, load_dir: Union[str, Path], disable_init: bool = False - ) -> "ModelParameter": - instance = cls() - return instance.load_components(load_dir, disable_init=disable_init) diff --git a/astrai/factory.py b/astrai/factory.py index 6c88c15..2109113 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -185,3 +185,6 @@ class BaseFactory(ABC, Generic[T]): def list_by_priority(cls, reverse: bool = False) -> List[str]: """List registered component names sorted by priority.""" return cls._registry.list_by_priority(reverse) + + +__all__ = ["Registry", "BaseFactory"] diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 5c86c53..14cc073 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -1,11 +1,11 @@ """Unified inference engine.""" import threading +import torch +import torch.nn as nn from typing import Any, Dict, Generator, List, Optional, Union -from astrai.config import ModelParameter -from astrai.tokenize.chat_template import build_prompt - +from astrai.tokenize.tokenizer import TextTokenizer from astrai.inference.scheduler import InferenceScheduler @@ -14,22 +14,18 @@ class GenerationRequest: def __init__( self, - query: Union[str, List[str]], + messages: List[Dict[str, str]], top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, max_len: int = 1024, - history: Optional[Any] = None, - system_prompt: Optional[str] = None, stream: bool = False, ): - self.query = query + self.messages = messages self.top_k = top_k self.top_p = top_p self.temperature = temperature self.max_len = max_len - self.history = history - self.system_prompt = system_prompt self.stream = stream self._validate() @@ -107,26 +103,41 @@ class InferenceEngine: def __init__( self, - parameter: 
ModelParameter,
-        max_batch_size: int = 16,
+        model: nn.Module,
+        tokenizer: TextTokenizer,
+        max_batch_size: int = 1,
         max_seq_len: Optional[int] = None,
     ):
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
+        """
+        Initialize inference engine with separate model and tokenizer.
 
-        model_params = next(self.model.parameters())
-        self.device = model_params.device
-        self.dtype = model_params.dtype
+        Args:
+            model: The language model for inference (nn.Module, e.g., Transformer)
+            tokenizer: The tokenizer for encoding/decoding text
+            max_batch_size: Maximum batch size for continuous batching
+            max_seq_len: Maximum sequence length (defaults to the model config's max_len)
+        """
+        self.model = model
+        self.tokenizer = tokenizer
+
+        # Get device and dtype from model parameters
+        try:
+            first_param = next(model.parameters())
+            device = first_param.device
+            dtype = first_param.dtype
+        except StopIteration:
+            # Model has no parameters, use default device/dtype
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            dtype = torch.float32
 
         self.scheduler = InferenceScheduler(
             model=self.model,
             tokenizer=self.tokenizer,
-            config=self.config,
             max_batch_size=max_batch_size,
             max_seq_len=max_seq_len,
-            device=self.device,
-            dtype=self.dtype,
+            device=device,
+            dtype=dtype,
         )
         self.kv_cache = self.scheduler.kv_cache
@@ -160,7 +171,8 @@ class InferenceEngine:
         self, request: GenerationRequest
     ) -> Union[Generator[str, None, None], str, List[str]]:
         """Generate with GenerationRequest object."""
-        prompt = build_prompt(request.query, request.history)
+        # Use the tokenizer's chat template with messages
+        prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
 
         return self.generate(
             prompt=prompt,
diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py
index 58ceaca..7e44b07 100644
--- a/astrai/inference/scheduler.py
+++ b/astrai/inference/scheduler.py
@@ -8,7 +8,8 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch import Tensor
 
-from astrai.config import ModelConfig
+from astrai.model.automodel import AutoModel
+from astrai.tokenize.tokenizer import TextTokenizer
 
 
 class TaskStatus:
@@ -98,23 +99,23 @@ class InferenceScheduler:
 
     def __init__(
         self,
-        model,
-        tokenizer,
-        config: ModelConfig,
+        model: AutoModel,
+        tokenizer: TextTokenizer,
         max_batch_size: int = 16,
         max_seq_len: Optional[int] = None,
-        device: str = "cuda",
-        dtype: torch.dtype = torch.bfloat16,
+        device: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
+        config = model.config
+
         self.model = model
         self.tokenizer = tokenizer
-        self.config = config
         self.max_batch_size = max_batch_size
         self.max_seq_len = max_seq_len or config.max_len
-        self.device = device
-        self.dtype = dtype
+        # Fall back to the model's own device/dtype when not specified
+        self.device = device or next(model.parameters()).device
+        self.dtype = dtype or next(model.parameters()).dtype
 
-        num_heads = config.n_kv_heads
+        num_kv_heads = config.n_kv_heads
         head_dim = config.dim // config.n_heads
         n_layers = config.n_layers
 
@@ -123,26 +124,26 @@ class InferenceScheduler:
                 max_batch_size,
                 self.max_seq_len,
                 n_layers,
-                num_heads,
+                num_kv_heads,
                 head_dim,
             ),
-            device=device,
-            dtype=dtype,
+            device=self.device,
+            dtype=self.dtype,
         )
         v_cache = torch.empty(
             (
                 max_batch_size,
                 self.max_seq_len,
                 n_layers,
-                num_heads,
+                num_kv_heads,
                 head_dim,
             ),
-            device=device,
-            dtype=dtype,
+            device=self.device,
+            dtype=self.dtype,
        )
         self.kv_cache = (k_cache, v_cache)
         self.seq_mask = torch.ones(
-            (max_batch_size, self.max_seq_len), device=device, dtype=torch.bool
+            (max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
        )
 
        self.waiting_queue: List[Task] = []
@@ -259,7 +260,7 @@ class InferenceScheduler:
             )
 
         with torch.inference_mode():
-            outputs = self.model(
+            self.model(
                 input_ids,
                 input_mask=input_mask,
                 start_pos=0,
diff --git a/astrai/inference/server.py b/astrai/inference/server.py
index c780add..aa6b801 100644
--- a/astrai/inference/server.py
+++ b/astrai/inference/server.py
@@ -3,8 +3,6 @@ Inference Server with Continuous Batching Support
 
 FastAPI server for inference with continuous batching.
 Provides OpenAI-compatible chat completion endpoints.
-
-Author: AstrAI Team
 """
 
 import json
@@ -19,14 +17,15 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 
-from astrai.config.param_config import ModelParameter
-from astrai.inference.engine import GenerationRequest, InferenceEngine
+from astrai.inference.engine import InferenceEngine
+from astrai.model import AutoModel
+from astrai.tokenize import TextTokenizer
 
 logger = logging.getLogger(__name__)
 
-# Global model parameter and engine (loaded once)
-_model_param: Optional[ModelParameter] = None
+# Global model and engine (loaded once)
 _engine: Optional[InferenceEngine] = None
+_model_param: Optional[Any] = None
 _project_root = Path(__file__).parent.parent.parent
 
 # Server configuration (set before running server)
@@ -95,13 +94,17 @@ def load_model(
     param_path = _project_root / "params"
     if not param_path.exists():
         raise FileNotFoundError(f"Parameter directory not found: {param_path}")
-    _model_param = ModelParameter.load(param_path, disable_init=True)
+
+    # Load tokenizer and model separately
+    tokenizer = TextTokenizer.from_pretrained(param_path)
+    _model_param = AutoModel.from_pretrained(param_path)
     _model_param.to(device=device, dtype=dtype)
     logger.info(f"Model loaded on {device} with dtype {dtype}")
 
-    # Initialize inference engine with continuous batching
+    # Initialize inference engine with separate model and tokenizer
     _engine = InferenceEngine(
-        parameter=_model_param,
+        model=_model_param,
+        tokenizer=tokenizer,
         max_batch_size=max_batch_size,
     )
     logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
@@ -164,27 +167,43 @@ def convert_messages_to_history(
     return system_prompt, history if history else None
 
 
-def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
+def convert_messages_to_prompt(
+    messages: List[ChatMessage], engine: Optional[InferenceEngine] = None
+) -> str:
     """Convert messages to prompt string.
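+
+    Prefers the engine tokenizer's chat template when an engine is provided;
+    otherwise falls back to a plain ChatML-style concatenation.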
    Args:
         messages: List of ChatMessage objects
+        engine: InferenceEngine instance for accessing tokenizer
 
     Returns:
         str: Formatted prompt string
     """
-    system_prompt, history = convert_messages_to_history(messages)
+    # Convert to dict format for chat template
+    msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
 
-    # Get the last user message as query
-    user_messages = [m.content for m in messages if m.role == "user"]
-    if not user_messages:
-        raise ValueError("No user message found")
-    query = user_messages[-1]
+    # Extract system prompt if present
+    system_prompt = None
+    filtered_messages = []
+    for msg in msg_dicts:
+        if msg["role"] == "system":
+            system_prompt = msg["content"]
+        else:
+            filtered_messages.append(msg)
 
-    # Build prompt using chat template
-    from astrai.tokenize.chat_template import build_prompt
+    # Use engine's tokenizer chat template if available
+    if engine is not None and engine.tokenizer is not None:
+        return engine.tokenizer.apply_chat_template(
+            filtered_messages, system_prompt=system_prompt, tokenize=False
+        )
 
-    return build_prompt(query, history)
+    # Fallback: plain ChatML-style concatenation when no engine tokenizer
+    # is available; keep any extracted system prompt.
+    prompt_parts = []
+    if system_prompt is not None:
+        prompt_parts.append(f"<|im▁start|>system\n{system_prompt}<|im▁end|>")
+    for msg in filtered_messages:
+        prompt_parts.append(
+            f"<|im▁start|>{msg['role']}\n{msg['content']}<|im▁end|>"
+        )
+    return "\n".join(prompt_parts) + "\n<|im▁start|>assistant\n"
 
 
 @app.get("/health")
@@ -213,8 +232,8 @@ async def chat_completion(request: ChatCompletionRequest):
     if _engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")
 
-    # Convert messages to prompt
-    prompt = convert_messages_to_prompt(request.messages)
+    # Convert messages to prompt using engine's tokenizer
+    prompt = convert_messages_to_prompt(request.messages, engine=_engine)
 
     if request.stream:
         # Streaming response (use synchronous generator)
@@ -294,15 +313,18 @@ async def generate(
     if _engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")
 
-    # Convert history format
-    hist: Optional[List[Tuple[str, str]]] = None
+    # Build messages for chat template
+    messages = []
     if history:
-        hist = [(h[0], h[1]) for h in history]
+        # Convert history format: List[List[str]] -> List[Dict]
+        for h in history:
+            if len(h) >= 2:
+                messages.append({"role": "user", "content": h[0]})
+                messages.append({"role": "assistant", "content": h[1]})
+    messages.append({"role": "user", "content": query})
 
-    # Build prompt
-    from astrai.tokenize.chat_template import build_prompt
-
-    prompt = build_prompt(query, hist)
+    # Use tokenizer's chat template
+    prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
 
     if stream:
         # Synchronous streaming
diff --git a/astrai/model/__init__.py b/astrai/model/__init__.py
index 516ccd3..35d74cc 100644
--- a/astrai/model/__init__.py
+++ b/astrai/model/__init__.py
@@ -6,5 +6,17 @@ from astrai.model.module import (
     RMSNorm,
 )
 from astrai.model.transformer import Transformer
+from astrai.model.automodel import AutoModel
 
-__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
+
+__all__ = [
+    # Modules
+    "Linear",
+    "RMSNorm",
+    "MLP",
+    "GQA",
+    "DecoderBlock",
+    # Models
+    "Transformer",
+    "AutoModel",
+]
diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py
new file mode 100644
index 0000000..b2cbad1
--- /dev/null
+++ b/astrai/model/automodel.py
@@ -0,0 +1,134 @@
+"""
+AutoModel base class for model loading and saving.
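+
+Typical usage (illustrative sketch; assumes a checkpoint directory that
+contains config.json and model.safetensors):
+
+    model = AutoModel.from_pretrained("params/")
+    model.save_pretrained("exported/")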
+"""
+
+import torch.nn as nn
+import safetensors.torch as st
+
+from pathlib import Path
+from contextlib import contextmanager
+from typing import Self, Union, Dict, Type
+
+from astrai.config import ModelConfig
+
+
+@contextmanager
+def _disable_random_init(enable: bool = True):
+    init_functions = [
+        "xavier_normal_",
+        "xavier_uniform_",
+        "kaiming_normal_",
+        "kaiming_uniform_",
+        "zeros_",
+        "ones_",
+        "constant_",
+        "normal_",
+        "uniform_",
+    ]
+    original_funcs = {}
+    for name in init_functions:
+        if enable and hasattr(nn.init, name):
+            original_funcs[name] = getattr(nn.init, name)
+            setattr(nn.init, name, lambda *args, **kwargs: None)
+    try:
+        yield
+    finally:
+        if enable:
+            for name, orig_func in original_funcs.items():
+                setattr(nn.init, name, orig_func)
+
+
+class AutoModel(nn.Module):
+    """
+    Autoregressive language model base class.
+    Provides model loading/saving and generation capabilities.
+    """
+
+    # Model registry - stored as class attribute
+    _registry: Dict[str, Type["AutoModel"]] = {}
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+    @classmethod
+    def register(cls, model_type: str):
+        """
+        Class method decorator to register model type.
+
+        Usage:
+            @AutoModel.register('transformer')
+            class Transformer(AutoModel):
+                ...
+        """
+
+        def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
+            cls._registry[model_type.lower()] = sub_cls
+            return sub_cls
+
+        return decorator
+
+    @classmethod
+    def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
+        """Get model class by model_type string."""
+        model_type = model_type.lower()
+        if model_type not in cls._registry:
+            available = list(cls._registry.keys())
+            raise ValueError(
+                f"Unknown model_type: {model_type}. Available: {available}"
+            )
+        return cls._registry[model_type]
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        path: Union[str, Path],
+        disable_random_init: bool = True,
+    ) -> "AutoModel":
+        """
+        Load a model from a checkpoint directory.
+
+        Args:
+            path: Directory containing config.json and model.safetensors.
+            disable_random_init: Skip random weight initialization before
+                the checkpoint weights are loaded.
+        """
+        model_path = Path(path)
+
+        # Load config
+        config = ModelConfig()
+        config_path = model_path / "config.json"
+        if config_path.exists():
+            config.load(str(config_path))
+        else:
+            raise FileNotFoundError(f"Config file not found: {config_path}")
+
+        # Called on the base class: dispatch on config.model_type.
+        # Called on a concrete subclass: that subclass loads itself.
+        if cls is AutoModel:
+            model_type = config.model_type or "transformer"
+            actual_cls = cls.get_model_class(model_type)
+        else:
+            actual_cls = cls
+
+        with _disable_random_init(enable=disable_random_init):
+            model = actual_cls(config)
+
+        # Load weights
+        weights_path = model_path / "model.safetensors"
+        if weights_path.exists():
+            state_dict = st.load_file(str(weights_path))
+            model.load_state_dict(state_dict, strict=False)
+
+        return model
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, Path],
+    ) -> None:
+        """Save config and weights to a checkpoint directory."""
+        save_path = Path(save_directory)
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        # Save config
+        self.config.save(str(save_path / "config.json"))
+
+        # Save weights
+        st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
+
+    def to(self, *args, **kwargs) -> Self:
+        """Move model to device/dtype."""
+        return super().to(*args, **kwargs)
diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py
index 281cd4a..2198a4b 100644
--- a/astrai/model/transformer.py
+++ b/astrai/model/transformer.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 from torch import Tensor
 
 from astrai.config.model_config import ModelConfig
+from astrai.model.automodel import
AutoModel from astrai.model.module import ( DecoderBlock, Embedding, @@ -66,9 +67,14 @@ def process_attention_mask( return attention_mask -class Transformer(nn.Module): +@AutoModel.register("transformer") +class Transformer(AutoModel): + """ + Transformer language model. + """ + def __init__(self, config: ModelConfig): - super().__init__() + super().__init__(config) self.config = config self.rotary_embeding = RotaryEmbedding( config.dim // config.n_heads, config.max_len @@ -97,16 +103,27 @@ class Transformer(nn.Module): if self.config.tie_weight: self.lm_head.weight = self.embed_tokens.weight - self._init_parameters() + self._init_weights() + + def _init_weights(self): + for param in self.parameters(): + if param.dim() > 1: + nn.init.normal_(param, mean=0.0, std=0.006) def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False): lm_head_key = "lm_head.weight" embed_key = "embed_tokens.weight" + # Make a copy to avoid modifying the original state_dict + state_dict = dict(state_dict) + if self.config.tie_weight: # same tensor - state_dict[lm_head_key] = state_dict[embed_key] + if embed_key in state_dict: + state_dict[lm_head_key] = state_dict[embed_key] else: + # If lm_head.weight exists in checkpoint, use it directly + # If not, copy from embed_tokens.weight if lm_head_key not in state_dict and embed_key in state_dict: # use clone to avoid sharing the same tensor state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) @@ -125,11 +142,6 @@ class Transformer(nn.Module): return state_dict - def _init_parameters(self): - for param in self.parameters(): - if param.dim() > 1: - nn.init.normal_(param, mean=0.0, std=0.006) - def forward( self, input_ids: Tensor, diff --git a/astrai/tokenize/__init__.py b/astrai/tokenize/__init__.py index 5c7fe0c..e98db15 100644 --- a/astrai/tokenize/__init__.py +++ b/astrai/tokenize/__init__.py @@ -1,22 +1,23 @@ from astrai.tokenize.tokenizer import ( - BaseTokenizer, + TextTokenizer, BpeTokenizer, - BaseTrainer, - BpeTrainer, ) +from astrai.tokenize.trainer import BpeTrainer from astrai.tokenize.chat_template import ( + ChatTemplate, HistoryType, MessageType, - build_prompt, ) +# Alias for compatibility +AutoTokenizer = TextTokenizer + __all__ = [ - "BaseTokenizer", + "TextTokenizer", + "AutoTokenizer", "BpeTokenizer", - "BaseTrainer", "BpeTrainer", + "ChatTemplate", "HistoryType", "MessageType", - "CHAT_TEMPLATES", - "build_prompt", ] diff --git a/astrai/tokenize/chat_template.py b/astrai/tokenize/chat_template.py index aa04f2f..6d4b2ee 100644 --- a/astrai/tokenize/chat_template.py +++ b/astrai/tokenize/chat_template.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional, Tuple, Any from jinja2 import Template from dataclasses import dataclass -from astrai.factory import Registry HistoryType = List[Tuple[str, str]] MessageType = Dict[str, str] @@ -75,213 +74,3 @@ class ChatTemplate: jinja_template = Template(self.template_str) return jinja_template.render(**variables) - - -# Global registry instance -_default_registry = Registry() - -# Default template name -_default_template_name = "chatml" - - -# Convenience functions -def register_chat_template( - name: str, - template_str: str, - description: str = "", - default_variables: Optional[Dict[str, Any]] = None, - special_tokens: Optional[Dict[str, str]] = None, -) -> ChatTemplate: - """Register a chat template in the global registry.""" - template = ChatTemplate( - name=name, - template_str=template_str, - description=description, - default_variables=default_variables, - 
special_tokens=special_tokens, - ) - _default_registry.register(name, template, category=None, priority=0) - return template - - -def set_default_chat_template(name: str) -> None: - """Set the default chat template name globally.""" - global _default_template_name - if not _default_registry.contains(name): - raise KeyError( - f"Chat template '{name}' not found. Available: {list(_default_registry.list_names())}" - ) - _default_template_name = name - - -def get_default_chat_template_name() -> str: - """Get the current default chat template name.""" - return _default_template_name - - -def get_chat_template(name: str) -> ChatTemplate: - """Get a chat template from the global registry.""" - return _default_registry.get(name) - - -def list_chat_templates() -> List[str]: - """List all registered chat template names.""" - return _default_registry.list_names() - - -def chat_template_exists(name: str) -> bool: - """Check if a chat template exists.""" - return _default_registry.contains(name) - - -def build_prompt( - query: str, - system_prompt: Optional[str] = None, - history: Optional[HistoryType] = None, - template: Optional[str] = None, - template_name: Optional[str] = None, - **extra_variables: Any, -) -> str: - """Build prompt using a registered chat template or a custom template string. - - This function maintains backward compatibility with the previous API. - - Args: - query: The current user query. - system_prompt: Optional system prompt. - history: Optional list of (user_msg, assistant_msg) pairs. - template: If provided, uses this exact Jinja2 template string (overrides template_name). - template_name: Name of a registered template to use (ignored if `template` is given). - If None, uses the globally set default template (see `set_default_chat_template`). - **extra_variables: Additional variables to pass to the template. - - Returns: - Rendered prompt string. - - Raises: - KeyError: If `template_name` is not registered. 
- """ - # Convert history to message format - messages: List[MessageType] = [] - if history: - for user_msg, assistant_msg in history: - messages.append({"role": "user", "content": user_msg}) - messages.append({"role": "assistant", "content": assistant_msg}) - messages.append({"role": "user", "content": query}) - - if template is not None: - # Use the provided template string directly - jinja_template = Template(template) - variables = {"messages": messages, **extra_variables} - if system_prompt is not None: - variables["system_prompt"] = system_prompt - return jinja_template.render(**variables) - else: - # Determine which template name to use - if template_name is None: - template_name = _default_template_name - # Use a registered template - chat_template = get_chat_template(template_name) - return chat_template.render( - messages=messages, - system_prompt=system_prompt, - **extra_variables, - ) - - -# Predefined templates -# ChatML template (original) -register_chat_template( - name="chatml", - template_str=( - "{%- if system_prompt -%}\n" - "{{ bos_token }}system\n" - "{{ system_prompt }}{{ eos_token }}\n" - "{%- endif -%}\n" - "{%- for message in messages -%}\n" - "{{ bos_token }}{{ message['role'] }}\n" - "{{ message['content'] }}{{ eos_token }}\n" - "{%- endfor -%}\n" - "{{ bos_token }}assistant\n" - ), - description="ChatML format with configurable special tokens.", - special_tokens={"bos_token": "<|im▁start|>", "eos_token": "<|im▁end|>"}, -) - -# Simplified template without special tokens (plain text) -register_chat_template( - name="plain", - template_str=( - "{%- if system_prompt -%}\n" - "System: {{ system_prompt }}\n" - "{%- endif -%}\n" - "{%- for message in messages -%}\n" - "{{ message['role']|capitalize }}: {{ message['content'] }}\n" - "{%- endfor -%}\n" - "Assistant:" - ), - description="Plain text format with role labels.", -) - -# Alpaca-style template -register_chat_template( - name="alpaca", - template_str=( - "{%- if system_prompt -%}\n" - "### Instruction:\n" - "{{ system_prompt }}\n" - "{%- endif -%}\n" - "### Input:\n" - "{{ messages[-1]['content'] }}\n" - "### Response:" - ), - description="Alpaca instruction‑response format (single‑turn).", - default_variables={}, -) - -# OpenAI chat format (approximation) -register_chat_template( - name="openai", - template_str=( - "{%- if system_prompt -%}\n" - "{{ bos_token }}system\n" - "{{ system_prompt }}{{ eos_token }}\n" - "{%- endif -%}\n" - "{%- for message in messages -%}\n" - "{{ bos_token }}{{ message['role'] }}\n" - "{{ message['content'] }}{{ eos_token }}\n" - "{%- endfor -%}\n" - "{{ bos_token }}assistant\n" - ), - description="OpenAI‑compatible chat format with configurable special tokens.", - special_tokens={"bos_token": "<|im▁start|>", "eos_token": "<|im▁end|>"}, -) - -# Llama‑2 style with [INST] tags -register_chat_template( - name="llama2", - template_str=( - "{%- if system_prompt -%}\n" - "<>\n" - "{{ system_prompt }}\n" - "<>\n" - "{%- endif -%}\n" - "[INST] {{ messages[-1]['content'] }} [/INST]" - ), - description="Llama‑2 style with [INST] tags (single‑turn).", - default_variables={}, -) - - -__all__ = [ - "ChatTemplate", - "register_chat_template", - "get_chat_template", - "list_chat_templates", - "chat_template_exists", - "build_prompt", - "set_default_chat_template", - "get_default_chat_template_name", - "HistoryType", - "MessageType", -] diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index 96cab47..c6e2fc7 100644 --- a/astrai/tokenize/tokenizer.py +++ 
b/astrai/tokenize/tokenizer.py
@@ -1,97 +1,254 @@
-from abc import ABC, abstractmethod
-from typing import List, Union
+"""
+Tokenizer module with implementation and auto-loading support.
+"""
+
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Union
 
 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
 
+from astrai.tokenize.chat_template import ChatTemplate
 
-class BaseTokenizer(ABC):
-    @abstractmethod
-    def _init_tokenizer(self):
-        pass
+class TextTokenizer:
+    """Base tokenizer class with automatic loading support."""
 
-    @abstractmethod
-    def save(self, path):
-        pass
+    TOKENIZER_CLASSES = {}  # Registry for auto-loading
 
-    @abstractmethod
-    def load(self, path):
-        pass
+    def __init__(
+        self,
+        path: Optional[Union[str, Path]] = None,
+        special_token_map: Optional[Dict[str, str]] = None,
+        chat_template: Optional[str] = None,
+    ):
+        self._tokenizer: Optional[Tokenizer] = None
+        self._chat_template: Optional[ChatTemplate] = None
+        self._special_token_map: Dict[str, str] = special_token_map or {}
+
+        if chat_template:
+            self.set_chat_template(chat_template)
+
+        if path:
+            self.load(path)
+
+    def load(self, path: Union[str, Path]):
+        """Load tokenizer from directory."""
+        path = Path(path)
+        tokenizer_file = path / "tokenizer.json"
+        config_file = path / "tokenizer_config.json"
+        self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
+
+        if config_file.exists():
+            with open(config_file, "r", encoding="utf-8") as f:
+                config = json.load(f)
+
+            if "special_tokens" in config:
+                self._special_token_map.update(config["special_tokens"])
+
+            # Load chat template from config
+            if "chat_template" in config:
+                self.set_chat_template(config["chat_template"])
+
+    @classmethod
+    def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
+        """Load tokenizer from a pretrained directory."""
+        return cls(path, **kwargs)
+
+    def save_pretrained(self, save_path: Union[str, Path]):
+        """
+        Save tokenizer to a pretrained directory.
+
+        Args:
+            save_path: Directory to write tokenizer.json into
+        """
+        save_path = Path(save_path)
+        save_path.mkdir(parents=True, exist_ok=True)
+        self._tokenizer.save(str(save_path / "tokenizer.json"))
+
+    @classmethod
+    def register_tokenizer(cls, name: str, tokenizer_class: type):
+        """
+        Register a new tokenizer class.
+
+        Args:
+            name: Name to register the tokenizer class under
+            tokenizer_class: The tokenizer class to register
+        """
+        cls.TOKENIZER_CLASSES[name] = tokenizer_class
 
-    @abstractmethod
     def encode(
         self,
         tokens: Union[str, List[str]],
         out_ids: bool = True,
-        add_special_tokens: bool = False,
+        is_pretokenized: bool = False,
+        add_special_tokens: bool = True,
     ) -> List:
-        pass
+        """Encode text to tokens or token IDs."""
+        if self._tokenizer is None:
+            raise RuntimeError(
+                "Tokenizer not initialized. Load or create a tokenizer first."
+            )
+
+        if isinstance(tokens, str):
+            encoded = self._tokenizer.encode(
+                tokens,
+                is_pretokenized=is_pretokenized,
+                add_special_tokens=add_special_tokens,
+            )
+            return encoded.ids if out_ids else encoded.tokens
+        else:
+            encoded_list = self._tokenizer.encode_batch(
+                tokens,
+                is_pretokenized=is_pretokenized,
+                add_special_tokens=add_special_tokens,
+            )
+            return [
+                encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
+            ]
 
-    @abstractmethod
     def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
-        pass
+        """Decode token IDs to text."""
+        if self._tokenizer is None:
+            raise RuntimeError(
+                "Tokenizer not initialized. Load or create a tokenizer first."
+            )
+
+        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
 
-    @abstractmethod
     def __len__(self) -> int:
-        pass
+        if self._tokenizer is None:
+            return 0
+        return self._tokenizer.get_vocab_size()
+
+    def __getattr__(self, key: str):
+        """
+        Dynamically intercept special token attribute access.
+        Supports three forms:
+        - tokenizer.bos → returns the registered token string
+        - tokenizer.bos_id → returns the corresponding integer ID
+        - tokenizer.stop_ids → returns the integer IDs of all special tokens
+        """
+        # Handle stop_ids: map every registered special token to its ID
+        if key == "stop_ids":
+            if self._tokenizer is None:
+                raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
+            return [
+                self._tokenizer.token_to_id(tok)
+                for tok in self._special_token_map.values()
+            ]
+
+        # Handle _id suffix (e.g., bos_id -> bos)
+        if key.endswith("_id"):
+            base_attr = key[:-3]  # Remove "_id"
+            token_str = self._special_token_map.get(base_attr)
+            if token_str is None:
+                return None
+            if self._tokenizer is None:
+                raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
+            return self._tokenizer.token_to_id(token_str)
+
+        # Handle regular string attributes
+        if key in self._special_token_map:
+            return self._special_token_map.get(key)
+
+        # Other attributes trigger the default AttributeError
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
 
     @property
-    @abstractmethod
-    def stop_ids(self) -> List[int]:
-        pass
+    def vocab_size(self) -> int:
+        return len(self)
 
     @property
-    @abstractmethod
-    def bos_id(self) -> int:
-        pass
+    def pad_id(self) -> Optional[int]:
+        """Return the pad token ID if available."""
+        pad_token = self._special_token_map.get("pad")
+        if pad_token is None or self._tokenizer is None:
+            return None
+        return self._tokenizer.token_to_id(pad_token)
 
-    @property
-    @abstractmethod
-    def eos_id(self) -> int:
-        pass
+    def set_chat_template(self, template: Union[str, ChatTemplate]):
+        """
+        Set the chat template for the tokenizer.
 
-    @property
-    @abstractmethod
-    def pad_id(self) -> int:
-        pass
+        Args:
+            template: Either a ChatTemplate instance or a Jinja2 template string.
+
+        Raises:
+            ValueError: If the template is neither a string nor a ChatTemplate.
+        """
+        if isinstance(template, str):
+            self._chat_template = ChatTemplate.from_string(template)
+        elif isinstance(template, ChatTemplate):
+            self._chat_template = template
+        else:
+            raise ValueError("Invalid template type, must be str or ChatTemplate.")
+
+    def apply_chat_template(
+        self,
+        messages: List[Dict[str, str]],
+        system_prompt: Optional[str] = None,
+        tokenize: bool = True,
+        **kwargs,
+    ) -> Union[str, List[int]]:
+        """
+        Apply the chat template to messages and optionally tokenize the result.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content'.
+            system_prompt: Optional system prompt string.
+            tokenize: Whether to return token IDs (True) or raw string (False).
+            **kwargs: Additional variables to pass to the template.
+
+        Returns:
+            Either the rendered string or list of token IDs.
+
+        Raises:
+            RuntimeError: If chat template is not set.
+        """
+        if self._chat_template is None:
+            raise RuntimeError(
+                "Chat template not set. Use set_chat_template() to set a template first."
+            )
+
+        # Render the template
+        rendered = self._chat_template.render(
+            messages=messages,
+            system_prompt=system_prompt,
+            **kwargs,
+        )
+
+        if tokenize:
+            return self.encode(rendered)
+
+        return rendered
 
 
-class BaseTrainer(ABC):
-    def __init__(self, tokenizer: BaseTokenizer):
-        self.tokenizer = tokenizer
+class BpeTokenizer(TextTokenizer):
+    """BPE tokenizer implementation."""
 
-    @abstractmethod
-    def train(self, files, vocab_size, min_freq, **kwargs):
-        pass
-
-    @abstractmethod
-    def train_from_iterator(self, iterator, vocab_size, min_freq, **kwargs):
-        pass
-
-
-class BpeTokenizer(BaseTokenizer):
     def __init__(
         self,
-        control_tokens: List[str] = None,
-        special_tokens: List[str] = None,
-        path=None,
+        special_token_map: Dict[str, str] = None,
+        path: Optional[str] = None,
+        chat_template: Optional[str] = None,
     ):
-        self._control_tokens = control_tokens or [
-            "<|begin▁of▁sentence|>",
-            "<|end▁of▁sentence|>",
-            "<|▁pad▁|>",
-        ]
-        self._special_tokens = special_tokens or [
-            "<|im▁start|>",
-            "<|im▁end|>",
-        ]
-        self._tokenizer = None
-        self._init_tokenizer()
-        if path is not None:
-            self.load(path)
+        special_token_map = special_token_map or {
+            "bos": "<|begin▁of▁sentence|>",
+            "eos": "<|end▁of▁sentence|>",
+            "pad": "<|▁pad▁|>",
+            "im_start": "<|im▁start|>",
+            "im_end": "<|im▁end|>",
+        }
+        super().__init__(
+            path, special_token_map=special_token_map, chat_template=chat_template
+        )
+        # TextTokenizer.__init__ resets _tokenizer, so build a fresh BPE
+        # tokenizer only when no pretrained file was loaded from `path`.
+        if self._tokenizer is None:
+            self._init_tokenizer()
 
     def _init_tokenizer(self):
+        """Initialize a new BPE tokenizer with default settings."""
         model = BPE()
         self._tokenizer = Tokenizer(model)
         self._tokenizer.normalizer = normalizers.Sequence(
@@ -105,108 +262,3 @@ class BpeTokenizer(BaseTokenizer):
         )
         self._tokenizer.decoder = decoders.ByteLevel()
         self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
-
-    def save(self, path):
-        self._tokenizer.save(path)
-
-    def load(self, path):
-        self._tokenizer = Tokenizer.from_file(path)
-
-    def encode(
-        self,
-        tokens: Union[str, List[str]],
-        out_ids: bool = True,
-        add_special_tokens: bool = False,
-    ) -> List:
-        if isinstance(tokens, str):
-            encoded = self._tokenizer.encode(
-                tokens, add_special_tokens=add_special_tokens
-            )
-            return encoded.ids if out_ids else encoded.tokens
-        else:
-            encoded_list = self._tokenizer.encode_batch(
-                tokens, add_special_tokens=add_special_tokens
-            )
-            return [
-                encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
-            ]
-
-    def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
-        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
-
-    def __len__(self) -> int:
-        return self._tokenizer.get_vocab_size()
-
-    @property
-    def stop_ids(self) -> List[int]:
-        stop_token = self._control_tokens + self._special_tokens
-        return [self._tokenizer.token_to_id(tok) for tok in stop_token]
-
-    @property
-    def bos_id(self) -> int:
-        return self._tokenizer.token_to_id(self._control_tokens[0])
-
-    @property
-    def eos_id(self) -> int:
-        return self._tokenizer.token_to_id(self._control_tokens[1])
-
-    @property
-    def pad_id(self) -> int:
-        return self._tokenizer.token_to_id(self._control_tokens[2])
-
-
-class BpeTrainer(BaseTrainer):
-    def __init__(self, tokenizer: BaseTokenizer):
-        super().__init__(tokenizer)
-
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length=18,
-    ):
-        assert reserved_token_size > len(self.tokenizer._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self.tokenizer._special_tokens)
-        )
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
-        assert detail_vocab_size > min_size
-
-        trainer = BpeTrainerImpl(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self.tokenizer._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-        return trainer, reserved_tokens
-
-    def train(self, files, vocab_size, min_freq, reserved_token_size=100, **kwargs):
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-    def train_from_iterator(
-        self, iterator, vocab_size, min_freq, reserved_token_size=100, **kwargs
-    ):
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train_from_iterator(
-            iterator=iterator, trainer=trainer
-        )
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
diff --git a/astrai/tokenize/trainer.py b/astrai/tokenize/trainer.py
new file mode 100644
index 0000000..20e5536
--- /dev/null
+++ b/astrai/tokenize/trainer.py
@@ -0,0 +1,108 @@
+"""
+BPE Tokenizer Trainer module.
+
+Provides training functionality for BPE tokenizers.
+"""
+
+from typing import List, Union
+
+from tokenizers import pre_tokenizers
+from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
+
+
+class BpeTrainer:
+    """BPE tokenizer trainer."""
+
+    def __init__(self, tokenizer):
+        """Initialize trainer with a tokenizer instance.
+
+        Args:
+            tokenizer: A BpeTokenizer instance
+        """
+        self.tokenizer = tokenizer
+
+    def _prepare_trainer(
+        self,
+        vocab_size: int,
+        min_freq: int,
+        reserved_token_size: int,
+        max_token_length: int = 18,
+    ):
+        """Prepare the BPE trainer with proper configuration."""
+        # BpeTokenizer exposes a single special_token_map; treat all of its
+        # values (bos/eos/pad/im_start/im_end, ...) as special tokens.
+        special_tokens = list(self.tokenizer._special_token_map.values())
+        assert reserved_token_size > len(special_tokens)
+        reserved_tokens = [
+            f"<|reserve{i:02d}|>"
+            for i in range(reserved_token_size - len(special_tokens))
+        ]
+        detail_vocab_size = vocab_size - (
+            len(reserved_tokens) + len(special_tokens)
+        )
+        alphabet = pre_tokenizers.ByteLevel.alphabet()
+        min_size = len(alphabet) + len(special_tokens)
+        assert detail_vocab_size > min_size
+
+        trainer = BpeTrainerImpl(
+            vocab_size=detail_vocab_size,
+            min_frequency=min_freq,
+            limit_alphabet=detail_vocab_size // 6,
+            max_token_length=max_token_length,
+            special_tokens=special_tokens,
+            initial_alphabet=alphabet,
+            show_progress=True,
+        )
+        return trainer, reserved_tokens
+
+    def train(
+        self,
+        files: Union[str, List[str]],
+        vocab_size: int,
+        min_freq: int,
+        reserved_token_size: int = 100,
+        **kwargs,
+    ):
+        """Train tokenizer from files.
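+
+        A minimal sketch of a training call (file paths are illustrative):
+
+            tokenizer = BpeTokenizer()
+            trainer = BpeTrainer(tokenizer)
+            trainer.train(files=["corpus.txt"], vocab_size=32000, min_freq=2)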
+
+        Args:
+            files: Path or list of paths to training files
+            vocab_size: Target vocabulary size
+            min_freq: Minimum frequency for tokens
+            reserved_token_size: Number of reserved tokens
+            **kwargs: Additional arguments
+        """
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
+        # Special tokens were already registered with the trainer above;
+        # only the reserved placeholder tokens still need to be added.
+        self.tokenizer._tokenizer.add_special_tokens(reserved_tokens)
+
+    def train_from_iterator(
+        self,
+        iterator,
+        vocab_size: int,
+        min_freq: int,
+        reserved_token_size: int = 100,
+        **kwargs,
+    ):
+        """Train tokenizer from iterator.
+
+        Args:
+            iterator: Iterator yielding training strings
+            vocab_size: Target vocabulary size
+            min_freq: Minimum frequency for tokens
+            reserved_token_size: Number of reserved tokens
+            **kwargs: Additional arguments
+        """
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train_from_iterator(
+            iterator=iterator, trainer=trainer
+        )
+        # As in train(), only the reserved tokens remain to be added.
+        self.tokenizer._tokenizer.add_special_tokens(reserved_tokens)
+
+
+__all__ = ["BpeTrainer"]
diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py
index a72a94c..698ecec 100644
--- a/scripts/demo/generate_ar.py
+++ b/scripts/demo/generate_ar.py
@@ -2,24 +2,32 @@ from pathlib import Path
 
 import torch
 
-from astrai.config.param_config import ModelParameter
 from astrai.inference import InferenceEngine
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
 
 
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 
 def generate_text():
-    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
-    param.to(device="cuda", dtype=torch.bfloat16)
+    # Load model from pretrained
+    model = AutoModel.from_pretrained(PARAMETER_ROOT)
+    model.to(device="cuda", dtype=torch.bfloat16)
+
+    # Load tokenizer from the same pretrained directory
+    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
 
     query = input(">> ")
 
-    engine = InferenceEngine(param)
+    engine = InferenceEngine(
+        model=model,
+        tokenizer=tokenizer,
+    )
     response = engine.generate(
         prompt=query,
         stream=False,
-        max_tokens=param.config.max_len,
+        max_tokens=model.config.max_len,
         temperature=0.8,
         top_p=0.95,
         top_k=50,
diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py
index 754ab4c..0662d3a 100644
--- a/scripts/demo/generate_batch.py
+++ b/scripts/demo/generate_batch.py
@@ -2,7 +2,7 @@ from pathlib import Path
 
 import torch
 
-from astrai.config.param_config import ModelParameter
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
 from astrai.inference import InferenceEngine
 
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
@@ -10,8 +10,10 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 
 def batch_generate():
-    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
-    param.to(device="cuda", dtype=torch.bfloat16)
+    # Load model and tokenizer separately
+    model = AutoModel.from_pretrained(PARAMETER_ROOT)
+    model.to(device="cuda", dtype=torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
 
     inputs = [
         "你好",
@@ -21,11 +23,14 @@
         "请问什么是显卡",
     ]
 
-    engine = InferenceEngine(param)
+    engine = InferenceEngine(
+        model=model,
+        tokenizer=tokenizer,
+    )
     responses = engine.generate(
         prompt=inputs,
         stream=False,
-        max_tokens=param.config.max_len,
+        max_tokens=model.config.max_len,
         temperature=0.8,
         top_p=0.95,
         top_k=50,
diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py
index f06e4fd..1a9af74 100644
--- a/scripts/demo/stream_chat.py
+++ b/scripts/demo/stream_chat.py
@@ -1,32 +1,39 @@
 from pathlib import Path
 
 import torch
-
-from astrai.config.param_config import ModelParameter
 from astrai.inference import InferenceEngine
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
+
 
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 
 
 def chat():
-    param = ModelParameter.load(PARAMETER_ROOT, disable_init=True)
-    param.to(device="cuda", dtype=torch.bfloat16)
+    model = AutoModel.from_pretrained(PARAMETER_ROOT)
+    tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
+    model.to(device="cuda", dtype=torch.bfloat16)
 
-    history = []
-    engine = InferenceEngine(param)
+    messages = []
+    engine = InferenceEngine(model=model, tokenizer=tokenizer)
 
     while True:
         query = input(">> ")
         if query == "!exit":
             break
 
+        # Add user message
+        messages.append({"role": "user", "content": query})
+
+        # Generate response
         full_response = ""
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False)
         for token in engine.generate(
-            prompt=query,
+            prompt=prompt,
             stream=True,
-            max_tokens=param.config.max_len,
+            max_tokens=model.config.max_len,
             temperature=0.8,
             top_p=0.95,
             top_k=50,
@@ -35,7 +42,8 @@ def chat():
             full_response += token
         print()
 
-        history.append((query, full_response.strip()))
+        # Add assistant response to messages
+        messages.append({"role": "assistant", "content": full_response.strip()})
 
 
 if __name__ == "__main__":
diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py
index 7011ec5..641c4ad 100644
--- a/scripts/tools/generate.py
+++ b/scripts/tools/generate.py
@@ -3,7 +3,8 @@ import json
 
 import torch
 
-from astrai.config.param_config import ModelParameter
+from astrai.model import AutoModel
+from astrai.tokenize import AutoTokenizer
 from astrai.inference import InferenceEngine
 
 
@@ -17,9 +18,9 @@ def processor(
     question_key: str,
     response_key: str,
 ):
-    param = ModelParameter.load(model_dir, disable_init=True)
-    param.to(device="cuda", dtype=torch.bfloat16)
-    engine = InferenceEngine(param)
+    # Load model and tokenizer separately
+    model = AutoModel.from_pretrained(model_dir)
+    model.to(device="cuda", dtype=torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    engine = InferenceEngine(model=model, tokenizer=tokenizer)
 
     with open(input_json_file, "r", encoding="utf-8") as f:
         input_data = [json.loads(line) for line in f]
@@ -29,7 +30,7 @@ def processor(
     responses = engine.generate(
         prompt=queries,
         stream=False,
-        max_tokens=param.config.max_len,
+        max_tokens=model.config.max_len,
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
diff --git a/scripts/tools/perplexity.py b/scripts/tools/perplexity.py
index a67a231..ebc4060 100644
--- a/scripts/tools/perplexity.py
+++ b/scripts/tools/perplexity.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 import tqdm
 from torch import Tensor
 
-from astrai.config.param_config import ModelParameter
+from astrai.model import AutoModel
 
 
 def compute_perplexity(
@@ -20,7 +20,7 @@ def compute_perplexity(
     where PPL = exp(-(1/N) * sum(log P(w_i | w_