refactor: 从 data 模块分离 tokenizer

This commit is contained in:
ViperEkura 2026-04-04 16:12:58 +08:00
parent b531232a9b
commit bd9741dc5f
16 changed files with 108 additions and 92 deletions

View File

@ -6,7 +6,8 @@ from astrai.config import (
TrainConfig,
)
from astrai.factory import BaseFactory
from astrai.data import BpeTokenizer, DatasetFactory
from astrai.data import DatasetFactory
from astrai.tokenizer import BpeTokenizer
from astrai.inference.generator import (
BatchGenerator,
EmbeddingEncoder,

View File

@ -7,7 +7,7 @@ import safetensors.torch as st
import torch.nn as nn
from astrai.config.model_config import ModelConfig
from astrai.data.tokenizer import BpeTokenizer
from astrai.tokenizer import BpeTokenizer
from astrai.model.transformer import Transformer

View File

@ -1,29 +1,19 @@
from astrai.data.dataset import (
BaseDataset,
DatasetFactory,
DPODataset,
GRPODataset,
BaseSegmentFetcher,
MultiSegmentFetcher,
SEQDataset,
SFTDataset,
)
from astrai.data.sampler import ResumableDistributedSampler
from astrai.data.tokenizer import BpeTokenizer
__all__ = [
# Base classes
"BaseDataset",
# Dataset implementations
"SEQDataset",
"SFTDataset",
"DPODataset",
"GRPODataset",
# Factory
"DatasetFactory",
# Fetchers
"BaseSegmentFetcher",
"MultiSegmentFetcher",
# Factory (DatasetFactory is alias for backward compatibility)
"DatasetFactory",
"DatasetFactory",
# Tokenizer and sampler
"BpeTokenizer",
# Sampler
"ResumableDistributedSampler",
]

View File

@ -9,7 +9,7 @@ from torch import Tensor
from torch.utils.data import Dataset
from astrai.factory import BaseFactory
from astrai.data.serialization import load_h5
from astrai.serialization import load_h5
class BaseSegmentFetcher:

View File

@ -1,78 +1,13 @@
from dataclasses import dataclass
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Generator, List, Optional, Tuple, Union
import torch
from jinja2 import Template
from torch import Tensor
from astrai.config.param_config import ModelParameter
from astrai.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]]
MessageType = Dict[str, str]
# Predefined chat templates using jinja2
CHAT_TEMPLATES: Dict[str, str] = {
"chatml": """{%- if system_prompt -%}
<imstart>system
{{ system_prompt }}<imend>
{%- endif -%}
{%- for message in messages -%}
<imstart>{{ message['role'] }}
{{ message['content'] }}<imend>
{%- endfor -%}
<imstart>assistant
""",
}
def build_prompt(
query: str,
system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None,
template: Optional[str] = None,
) -> str:
"""Build prompt using jinja2 template for query and history.
Args:
query (str): query string.
system_prompt (Optional[str]): system prompt string.
history (Optional[HistoryType]): history list of query and response.
template (Optional[str]): jinja2 template string. If None, uses default chatml template.
Returns:
str: prompt string formatted according to the template.
Example:
# Use default template
prompt = build_prompt(query="Hello", history=[...])
# Use custom template
custom_template = '''
{%- for msg in messages -%}
{{ msg['role'] }}: {{ msg['content'] }}
{%- endfor -%}
'''
prompt = build_prompt(query="Hello", template=custom_template)
"""
# Convert history to message format
messages: List[MessageType] = []
if history:
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": query})
# Use provided template or default chatml template
template_str = template if template is not None else CHAT_TEMPLATES["chatml"]
# Render template
jinja_template = Template(template_str)
return jinja_template.render(
messages=messages,
system_prompt=system_prompt,
)
from astrai.tokenizer.chat_template import HistoryType, build_prompt
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:

View File

@ -0,0 +1,23 @@
from astrai.tokenizer.tokenizer import (
BaseTokenizer,
BpeTokenizer,
BaseTrainer,
BpeTrainer,
)
from astrai.tokenizer.chat_template import (
HistoryType,
MessageType,
CHAT_TEMPLATES,
build_prompt,
)
__all__ = [
    # Tokenizer base class and BPE implementation
    "BaseTokenizer",
    "BpeTokenizer",
    # Tokenizer trainers
    "BaseTrainer",
    "BpeTrainer",
    # Chat-template helpers (re-exported from chat_template)
    "HistoryType",
    "MessageType",
    "CHAT_TEMPLATES",
    "build_prompt",
]

View File

@ -0,0 +1,67 @@
from typing import Dict, List, Optional, Tuple
from jinja2 import Template
HistoryType = List[Tuple[str, str]]
MessageType = Dict[str, str]
# Predefined chat templates using jinja2
CHAT_TEMPLATES: Dict[str, str] = {
"chatml": """{%- if system_prompt -%}
<imstart>system
{{ system_prompt }}<imend>
{%- endif -%}
{%- for message in messages -%}
<imstart>{{ message['role'] }}
{{ message['content'] }}<imend>
{%- endfor -%}
<imstart>assistant
""",
}
def build_prompt(
    query: str,
    system_prompt: Optional[str] = None,
    history: Optional[HistoryType] = None,
    template: Optional[str] = None,
) -> str:
    """Render a chat prompt for *query* and optional *history* with jinja2.

    Args:
        query (str): The current user query.
        system_prompt (Optional[str]): System prompt injected by the template
            when present.
        history (Optional[HistoryType]): Prior (user, assistant) turn pairs.
        template (Optional[str]): Custom jinja2 template source. Falls back to
            the built-in "chatml" template when None.

    Returns:
        str: The prompt string produced by rendering the template.

    Example:
        # Use default template
        prompt = build_prompt(query="Hello", history=[...])

        # Use custom template
        custom_template = '''
        {%- for msg in messages -%}
        {{ msg['role'] }}: {{ msg['content'] }}
        {%- endfor -%}
        '''
        prompt = build_prompt(query="Hello", template=custom_template)
    """
    # Flatten (user, assistant) pairs into role-tagged messages, then
    # append the current query as the final user turn.
    messages: List[MessageType] = []
    for user_msg, assistant_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": query})

    # Fall back to the default chatml template when none was supplied.
    chosen = CHAT_TEMPLATES["chatml"] if template is None else template

    # Render and return the final prompt text.
    return Template(chosen).render(
        messages=messages,
        system_prompt=system_prompt,
    )

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from astrai.data.serialization import Checkpoint
from astrai.serialization import Checkpoint
from astrai.parallel import only_on_rank
from astrai.trainer.metric_util import (
ctx_get_grad_max,

View File

@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig
from astrai.data import ResumableDistributedSampler
from astrai.data.serialization import Checkpoint
from astrai.serialization import Checkpoint
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.trainer.strategy import BaseStrategy, StrategyFactory

View File

@ -2,7 +2,7 @@ import logging
from typing import List, Optional
from astrai.config import TrainConfig
from astrai.data.serialization import Checkpoint
from astrai.serialization import Checkpoint
from astrai.parallel.setup import spawn_parallel_fn
from astrai.trainer.train_callback import (
TrainCallback,

View File

@ -11,7 +11,7 @@ from tokenizers import pre_tokenizers
from torch.utils.data import Dataset
from astrai.config.model_config import ModelConfig
from astrai.data.tokenizer import BpeTokenizer, BpeTrainer
from astrai.tokenizer import BpeTokenizer, BpeTrainer
from astrai.model.transformer import Transformer

View File

@ -5,7 +5,7 @@ import torch.distributed as dist
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from astrai.data.serialization import Checkpoint
from astrai.serialization import Checkpoint
from astrai.parallel.setup import get_rank, spawn_parallel_fn

View File

@ -2,7 +2,7 @@ import numpy as np
import torch
from astrai.data.dataset import DatasetFactory
from astrai.data.serialization import save_h5
from astrai.serialization import save_h5
def test_dataset_loader_random_paths(base_test_env):

View File

@ -4,7 +4,7 @@ import numpy as np
import torch
from astrai.config.train_config import TrainConfig
from astrai.data.serialization import Checkpoint
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer