From bd9741dc5fed8acfb3bace96369ef0223bf6dae4 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Apr 2026 16:12:58 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BB=8Edata=20=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E5=88=86=E7=A6=BBtokenizer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/__init__.py | 3 +- astrai/config/param_config.py | 2 +- astrai/data/__init__.py | 20 ++----- astrai/data/dataset.py | 2 +- astrai/inference/generator.py | 69 +------------------------ astrai/{data => }/serialization.py | 0 astrai/tokenizer/__init__.py | 23 +++++++++ astrai/tokenizer/chat_template.py | 67 ++++++++++++++++++++++++ astrai/{data => tokenizer}/tokenizer.py | 0 astrai/trainer/train_callback.py | 2 +- astrai/trainer/train_context.py | 2 +- astrai/trainer/trainer.py | 2 +- tests/conftest.py | 2 +- tests/data/test_checkpoint.py | 2 +- tests/data/test_dataset.py | 2 +- tests/trainer/test_early_stopping.py | 2 +- 16 files changed, 108 insertions(+), 92 deletions(-) rename astrai/{data => }/serialization.py (100%) create mode 100644 astrai/tokenizer/__init__.py create mode 100644 astrai/tokenizer/chat_template.py rename astrai/{data => tokenizer}/tokenizer.py (100%) diff --git a/astrai/__init__.py b/astrai/__init__.py index a8191f4..65ab6a6 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -6,7 +6,8 @@ from astrai.config import ( TrainConfig, ) from astrai.factory import BaseFactory -from astrai.data import BpeTokenizer, DatasetFactory +from astrai.data import DatasetFactory +from astrai.tokenizer import BpeTokenizer from astrai.inference.generator import ( BatchGenerator, EmbeddingEncoder, diff --git a/astrai/config/param_config.py b/astrai/config/param_config.py index 5a4b9a3..7afc9bb 100644 --- a/astrai/config/param_config.py +++ b/astrai/config/param_config.py @@ -7,7 +7,7 @@ import safetensors.torch as st import torch.nn as nn from astrai.config.model_config import ModelConfig -from astrai.data.tokenizer import BpeTokenizer +from astrai.tokenizer import BpeTokenizer from astrai.model.transformer import Transformer diff --git a/astrai/data/__init__.py b/astrai/data/__init__.py index 337788c..d8df820 100644 --- a/astrai/data/__init__.py +++ b/astrai/data/__init__.py @@ -1,29 +1,19 @@ from astrai.data.dataset import ( BaseDataset, DatasetFactory, - DPODataset, - GRPODataset, + BaseSegmentFetcher, MultiSegmentFetcher, - SEQDataset, - SFTDataset, ) from astrai.data.sampler import ResumableDistributedSampler -from astrai.data.tokenizer import BpeTokenizer __all__ = [ # Base classes "BaseDataset", - # Dataset implementations - "SEQDataset", - "SFTDataset", - "DPODataset", - "GRPODataset", + # Factory + "DatasetFactory", # Fetchers + "BaseSegmentFetcher", "MultiSegmentFetcher", - # Factory (DatasetFactory is alias for backward compatibility) - "DatasetFactory", - "DatasetFactory", - # Tokenizer and sampler - "BpeTokenizer", + # Sampler "ResumableDistributedSampler", ] diff --git a/astrai/data/dataset.py b/astrai/data/dataset.py index 25a289b..66a4b65 100644 --- a/astrai/data/dataset.py +++ b/astrai/data/dataset.py @@ -9,7 +9,7 @@ from torch import Tensor from torch.utils.data import Dataset from astrai.factory import BaseFactory -from astrai.data.serialization import load_h5 +from astrai.serialization import load_h5 class BaseSegmentFetcher: diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py index 397e5e8..68c2e2e 100644 --- a/astrai/inference/generator.py +++ b/astrai/inference/generator.py @@ -1,78 +1,13 @@ from dataclasses import dataclass -from typing import Dict, Generator, List, Optional, Tuple, Union +from typing import Generator, List, Optional, Tuple, Union import torch -from jinja2 import Template from torch import Tensor from astrai.config.param_config import ModelParameter from astrai.factory import BaseFactory from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager - -HistoryType = List[Tuple[str, str]] -MessageType = Dict[str, str] - -# Predefined chat templates using jinja2 -CHAT_TEMPLATES: Dict[str, str] = { - "chatml": """{%- if system_prompt -%} -<|im▁start|>system -{{ system_prompt }}<|im▁end|> -{%- endif -%} -{%- for message in messages -%} -<|im▁start|>{{ message['role'] }} -{{ message['content'] }}<|im▁end|> -{%- endfor -%} -<|im▁start|>assistant -""", -} - - -def build_prompt( - query: str, - system_prompt: Optional[str] = None, - history: Optional[HistoryType] = None, - template: Optional[str] = None, -) -> str: - """Build prompt using jinja2 template for query and history. - - Args: - query (str): query string. - system_prompt (Optional[str]): system prompt string. - history (Optional[HistoryType]): history list of query and response. - template (Optional[str]): jinja2 template string. If None, uses default chatml template. - - Returns: - str: prompt string formatted according to the template. - - Example: - # Use default template - prompt = build_prompt(query="Hello", history=[...]) - - # Use custom template - custom_template = ''' - {%- for msg in messages -%} - {{ msg['role'] }}: {{ msg['content'] }} - {%- endfor -%} - ''' - prompt = build_prompt(query="Hello", template=custom_template) - """ - # Convert history to message format - messages: List[MessageType] = [] - if history: - for user_msg, assistant_msg in history: - messages.append({"role": "user", "content": user_msg}) - messages.append({"role": "assistant", "content": assistant_msg}) - messages.append({"role": "user", "content": query}) - - # Use provided template or default chatml template - template_str = template if template is not None else CHAT_TEMPLATES["chatml"] - - # Render template - jinja_template = Template(template_str) - return jinja_template.render( - messages=messages, - system_prompt=system_prompt, - ) +from astrai.tokenizer.chat_template import HistoryType, build_prompt def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]: diff --git a/astrai/data/serialization.py b/astrai/serialization.py similarity index 100% rename from astrai/data/serialization.py rename to astrai/serialization.py diff --git a/astrai/tokenizer/__init__.py b/astrai/tokenizer/__init__.py new file mode 100644 index 0000000..6f31455 --- /dev/null +++ b/astrai/tokenizer/__init__.py @@ -0,0 +1,23 @@ +from astrai.tokenizer.tokenizer import ( + BaseTokenizer, + BpeTokenizer, + BaseTrainer, + BpeTrainer, +) +from astrai.tokenizer.chat_template import ( + HistoryType, + MessageType, + CHAT_TEMPLATES, + build_prompt, +) + +__all__ = [ + "BaseTokenizer", + "BpeTokenizer", + "BaseTrainer", + "BpeTrainer", + "HistoryType", + "MessageType", + "CHAT_TEMPLATES", + "build_prompt", +] diff --git a/astrai/tokenizer/chat_template.py b/astrai/tokenizer/chat_template.py new file mode 100644 index 0000000..f22ffef --- /dev/null +++ b/astrai/tokenizer/chat_template.py @@ -0,0 +1,67 @@ +from typing import Dict, List, Optional, Tuple +from jinja2 import Template + +HistoryType = List[Tuple[str, str]] +MessageType = Dict[str, str] + +# Predefined chat templates using jinja2 +CHAT_TEMPLATES: Dict[str, str] = { + "chatml": """{%- if system_prompt -%} +<|im▁start|>system +{{ system_prompt }}<|im▁end|> +{%- endif -%} +{%- for message in messages -%} +<|im▁start|>{{ message['role'] }} +{{ message['content'] }}<|im▁end|> +{%- endfor -%} +<|im▁start|>assistant +""", +} + + +def build_prompt( + query: str, + system_prompt: Optional[str] = None, + history: Optional[HistoryType] = None, + template: Optional[str] = None, +) -> str: + """Build prompt using jinja2 template for query and history. + + Args: + query (str): query string. + system_prompt (Optional[str]): system prompt string. + history (Optional[HistoryType]): history list of query and response. + template (Optional[str]): jinja2 template string. If None, uses default chatml template. + + Returns: + str: prompt string formatted according to the template. + + Example: + # Use default template + prompt = build_prompt(query="Hello", history=[...]) + + # Use custom template + custom_template = ''' + {%- for msg in messages -%} + {{ msg['role'] }}: {{ msg['content'] }} + {%- endfor -%} + ''' + prompt = build_prompt(query="Hello", template=custom_template) + """ + # Convert history to message format + messages: List[MessageType] = [] + if history: + for user_msg, assistant_msg in history: + messages.append({"role": "user", "content": user_msg}) + messages.append({"role": "assistant", "content": assistant_msg}) + messages.append({"role": "user", "content": query}) + + # Use provided template or default chatml template + template_str = template if template is not None else CHAT_TEMPLATES["chatml"] + + # Render template + jinja_template = Template(template_str) + return jinja_template.render( + messages=messages, + system_prompt=system_prompt, + ) diff --git a/astrai/data/tokenizer.py b/astrai/tokenizer/tokenizer.py similarity index 100% rename from astrai/data/tokenizer.py rename to astrai/tokenizer/tokenizer.py diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 6ca09ae..c79508d 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -8,7 +8,7 @@ import torch.nn as nn from torch.nn.utils import clip_grad_norm_ from tqdm import tqdm -from astrai.data.serialization import Checkpoint +from astrai.serialization import Checkpoint from astrai.parallel import only_on_rank from astrai.trainer.metric_util import ( ctx_get_grad_max, diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 864cf52..abdb1fd 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from astrai.config.train_config import TrainConfig from astrai.data import ResumableDistributedSampler -from astrai.data.serialization import Checkpoint +from astrai.serialization import Checkpoint from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.trainer.strategy import BaseStrategy, StrategyFactory diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 189e1dd..817f258 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -2,7 +2,7 @@ import logging from typing import List, Optional from astrai.config import TrainConfig -from astrai.data.serialization import Checkpoint +from astrai.serialization import Checkpoint from astrai.parallel.setup import spawn_parallel_fn from astrai.trainer.train_callback import ( TrainCallback, diff --git a/tests/conftest.py b/tests/conftest.py index ed44218..cd1f09c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from tokenizers import pre_tokenizers from torch.utils.data import Dataset from astrai.config.model_config import ModelConfig -from astrai.data.tokenizer import BpeTokenizer, BpeTrainer +from astrai.tokenizer import BpeTokenizer, BpeTrainer from astrai.model.transformer import Transformer diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index 1285a1e..5e4d6f9 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -5,7 +5,7 @@ import torch.distributed as dist from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR -from astrai.data.serialization import Checkpoint +from astrai.serialization import Checkpoint from astrai.parallel.setup import get_rank, spawn_parallel_fn diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 580b3ea..4b3f4e2 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -2,7 +2,7 @@ import numpy as np import torch from astrai.data.dataset import DatasetFactory -from astrai.data.serialization import save_h5 +from astrai.serialization import save_h5 def test_dataset_loader_random_paths(base_test_env): diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index 052fd2a..c2d84c5 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -4,7 +4,7 @@ import numpy as np import torch from astrai.config.train_config import TrainConfig -from astrai.data.serialization import Checkpoint +from astrai.serialization import Checkpoint from astrai.trainer.schedule import SchedulerFactory from astrai.trainer.trainer import Trainer