refactor: 从 data 模块分离 tokenizer
This commit is contained in:
parent
b531232a9b
commit
bd9741dc5f
|
|
@ -6,7 +6,8 @@ from astrai.config import (
|
|||
TrainConfig,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.data import BpeTokenizer, DatasetFactory
|
||||
from astrai.data import DatasetFactory
|
||||
from astrai.tokenizer import BpeTokenizer
|
||||
from astrai.inference.generator import (
|
||||
BatchGenerator,
|
||||
EmbeddingEncoder,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import safetensors.torch as st
|
|||
import torch.nn as nn
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.data.tokenizer import BpeTokenizer
|
||||
from astrai.tokenizer import BpeTokenizer
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,29 +1,19 @@
|
|||
from astrai.data.dataset import (
|
||||
BaseDataset,
|
||||
DatasetFactory,
|
||||
DPODataset,
|
||||
GRPODataset,
|
||||
BaseSegmentFetcher,
|
||||
MultiSegmentFetcher,
|
||||
SEQDataset,
|
||||
SFTDataset,
|
||||
)
|
||||
from astrai.data.sampler import ResumableDistributedSampler
|
||||
from astrai.data.tokenizer import BpeTokenizer
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
"BaseDataset",
|
||||
# Dataset implementations
|
||||
"SEQDataset",
|
||||
"SFTDataset",
|
||||
"DPODataset",
|
||||
"GRPODataset",
|
||||
# Factory
|
||||
"DatasetFactory",
|
||||
# Fetchers
|
||||
"BaseSegmentFetcher",
|
||||
"MultiSegmentFetcher",
|
||||
# Factory (DatasetFactory is alias for backward compatibility)
|
||||
"DatasetFactory",
|
||||
"DatasetFactory",
|
||||
# Tokenizer and sampler
|
||||
"BpeTokenizer",
|
||||
# Sampler
|
||||
"ResumableDistributedSampler",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from torch import Tensor
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.data.serialization import load_h5
|
||||
from astrai.serialization import load_h5
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
|
|
|
|||
|
|
@ -1,78 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, Generator, List, Optional, Tuple, Union
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from jinja2 import Template
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
MessageType = Dict[str, str]
|
||||
|
||||
# Predefined chat templates using jinja2
|
||||
CHAT_TEMPLATES: Dict[str, str] = {
|
||||
"chatml": """{%- if system_prompt -%}
|
||||
<|im▁start|>system
|
||||
{{ system_prompt }}<|im▁end|>
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
<|im▁start|>{{ message['role'] }}
|
||||
{{ message['content'] }}<|im▁end|>
|
||||
{%- endfor -%}
|
||||
<|im▁start|>assistant
|
||||
""",
|
||||
}
|
||||
|
||||
|
||||
def build_prompt(
    query: str,
    system_prompt: Optional[str] = None,
    history: Optional[HistoryType] = None,
    template: Optional[str] = None,
) -> str:
    """Render a chat prompt from the current query and any earlier turns.

    The conversation is flattened into a list of ``{"role", "content"}``
    dicts (each history pair contributes a ``user`` entry followed by an
    ``assistant`` entry, and the current query is appended last as a
    ``user`` entry). That list is handed to a jinja2 template as
    ``messages`` together with ``system_prompt``.

    Args:
        query: The latest user message.
        system_prompt: Optional system prompt passed to the template as-is.
        history: Optional earlier turns as ``(user, assistant)`` pairs.
        template: Optional jinja2 template source. When ``None`` the
            ``"chatml"`` entry of ``CHAT_TEMPLATES`` is used.

    Returns:
        The rendered prompt string.

    Example:
        # Default template
        prompt = build_prompt(query="Hello", history=[...])

        # Custom template
        custom_template = '''
        {%- for msg in messages -%}
        {{ msg['role'] }}: {{ msg['content'] }}
        {%- endfor -%}
        '''
        prompt = build_prompt(query="Hello", template=custom_template)
    """
    # Flatten (user, assistant) pairs into alternating role-tagged messages.
    turns: List[MessageType] = []
    for asked, answered in history or []:
        turns.append({"role": "user", "content": asked})
        turns.append({"role": "assistant", "content": answered})
    turns.append({"role": "user", "content": query})

    # Fall back to the predefined chatml template only when none was given.
    source = CHAT_TEMPLATES["chatml"] if template is None else template

    return Template(source).render(messages=turns, system_prompt=system_prompt)
|
||||
from astrai.tokenizer.chat_template import HistoryType, build_prompt
|
||||
|
||||
|
||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
from astrai.tokenizer.tokenizer import (
|
||||
BaseTokenizer,
|
||||
BpeTokenizer,
|
||||
BaseTrainer,
|
||||
BpeTrainer,
|
||||
)
|
||||
from astrai.tokenizer.chat_template import (
|
||||
HistoryType,
|
||||
MessageType,
|
||||
CHAT_TEMPLATES,
|
||||
build_prompt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseTokenizer",
|
||||
"BpeTokenizer",
|
||||
"BaseTrainer",
|
||||
"BpeTrainer",
|
||||
"HistoryType",
|
||||
"MessageType",
|
||||
"CHAT_TEMPLATES",
|
||||
"build_prompt",
|
||||
]
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
from jinja2 import Template
|
||||
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
MessageType = Dict[str, str]
|
||||
|
||||
# Predefined chat templates using jinja2
|
||||
CHAT_TEMPLATES: Dict[str, str] = {
|
||||
"chatml": """{%- if system_prompt -%}
|
||||
<|im▁start|>system
|
||||
{{ system_prompt }}<|im▁end|>
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
<|im▁start|>{{ message['role'] }}
|
||||
{{ message['content'] }}<|im▁end|>
|
||||
{%- endfor -%}
|
||||
<|im▁start|>assistant
|
||||
""",
|
||||
}
|
||||
|
||||
|
||||
def build_prompt(
    query: str,
    system_prompt: Optional[str] = None,
    history: Optional[HistoryType] = None,
    template: Optional[str] = None,
) -> str:
    """Render a chat prompt from the current query and any earlier turns.

    The conversation is flattened into a list of ``{"role", "content"}``
    dicts (each history pair contributes a ``user`` entry followed by an
    ``assistant`` entry, and the current query is appended last as a
    ``user`` entry). That list is handed to a jinja2 template as
    ``messages`` together with ``system_prompt``.

    Args:
        query: The latest user message.
        system_prompt: Optional system prompt passed to the template as-is.
        history: Optional earlier turns as ``(user, assistant)`` pairs.
        template: Optional jinja2 template source. When ``None`` the
            ``"chatml"`` entry of ``CHAT_TEMPLATES`` is used.

    Returns:
        The rendered prompt string.

    Example:
        # Default template
        prompt = build_prompt(query="Hello", history=[...])

        # Custom template
        custom_template = '''
        {%- for msg in messages -%}
        {{ msg['role'] }}: {{ msg['content'] }}
        {%- endfor -%}
        '''
        prompt = build_prompt(query="Hello", template=custom_template)
    """
    # Flatten (user, assistant) pairs into alternating role-tagged messages.
    turns: List[MessageType] = []
    for asked, answered in history or []:
        turns.append({"role": "user", "content": asked})
        turns.append({"role": "assistant", "content": answered})
    turns.append({"role": "user", "content": query})

    # Fall back to the predefined chatml template only when none was given.
    source = CHAT_TEMPLATES["chatml"] if template is None else template

    return Template(source).render(messages=turns, system_prompt=system_prompt)
|
||||
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
from torch.nn.utils import clip_grad_norm_
|
||||
from tqdm import tqdm
|
||||
|
||||
from astrai.data.serialization import Checkpoint
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.data import ResumableDistributedSampler
|
||||
from astrai.data.serialization import Checkpoint
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
from typing import List, Optional
|
||||
|
||||
from astrai.config import TrainConfig
|
||||
from astrai.data.serialization import Checkpoint
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel.setup import spawn_parallel_fn
|
||||
from astrai.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from tokenizers import pre_tokenizers
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.data.tokenizer import BpeTokenizer, BpeTrainer
|
||||
from astrai.tokenizer import BpeTokenizer, BpeTrainer
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch.distributed as dist
|
|||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from astrai.data.serialization import Checkpoint
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel.setup import get_rank, spawn_parallel_fn
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from astrai.data.dataset import DatasetFactory
|
||||
from astrai.data.serialization import save_h5
|
||||
from astrai.serialization import save_h5
|
||||
|
||||
|
||||
def test_dataset_loader_random_paths(base_test_env):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.data.serialization import Checkpoint
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue