refactor: 从data 模块分离tokenizer
This commit is contained in:
parent
b531232a9b
commit
bd9741dc5f
|
|
@ -6,7 +6,8 @@ from astrai.config import (
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.data import BpeTokenizer, DatasetFactory
|
from astrai.data import DatasetFactory
|
||||||
|
from astrai.tokenizer import BpeTokenizer
|
||||||
from astrai.inference.generator import (
|
from astrai.inference.generator import (
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import safetensors.torch as st
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.data.tokenizer import BpeTokenizer
|
from astrai.tokenizer import BpeTokenizer
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,19 @@
|
||||||
from astrai.data.dataset import (
|
from astrai.data.dataset import (
|
||||||
BaseDataset,
|
BaseDataset,
|
||||||
DatasetFactory,
|
DatasetFactory,
|
||||||
DPODataset,
|
BaseSegmentFetcher,
|
||||||
GRPODataset,
|
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
SEQDataset,
|
|
||||||
SFTDataset,
|
|
||||||
)
|
)
|
||||||
from astrai.data.sampler import ResumableDistributedSampler
|
from astrai.data.sampler import ResumableDistributedSampler
|
||||||
from astrai.data.tokenizer import BpeTokenizer
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base classes
|
# Base classes
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
# Dataset implementations
|
# Factory
|
||||||
"SEQDataset",
|
"DatasetFactory",
|
||||||
"SFTDataset",
|
|
||||||
"DPODataset",
|
|
||||||
"GRPODataset",
|
|
||||||
# Fetchers
|
# Fetchers
|
||||||
|
"BaseSegmentFetcher",
|
||||||
"MultiSegmentFetcher",
|
"MultiSegmentFetcher",
|
||||||
# Factory (DatasetFactory is alias for backward compatibility)
|
# Sampler
|
||||||
"DatasetFactory",
|
|
||||||
"DatasetFactory",
|
|
||||||
# Tokenizer and sampler
|
|
||||||
"BpeTokenizer",
|
|
||||||
"ResumableDistributedSampler",
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.data.serialization import load_h5
|
from astrai.serialization import load_h5
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class BaseSegmentFetcher:
|
||||||
|
|
|
||||||
|
|
@ -1,78 +1,13 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Generator, List, Optional, Tuple, Union
|
from typing import Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from jinja2 import Template
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||||
|
from astrai.tokenizer.chat_template import HistoryType, build_prompt
|
||||||
HistoryType = List[Tuple[str, str]]
|
|
||||||
MessageType = Dict[str, str]
|
|
||||||
|
|
||||||
# Predefined chat templates using jinja2
|
|
||||||
CHAT_TEMPLATES: Dict[str, str] = {
|
|
||||||
"chatml": """{%- if system_prompt -%}
|
|
||||||
<|im▁start|>system
|
|
||||||
{{ system_prompt }}<|im▁end|>
|
|
||||||
{%- endif -%}
|
|
||||||
{%- for message in messages -%}
|
|
||||||
<|im▁start|>{{ message['role'] }}
|
|
||||||
{{ message['content'] }}<|im▁end|>
|
|
||||||
{%- endfor -%}
|
|
||||||
<|im▁start|>assistant
|
|
||||||
""",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(
|
|
||||||
query: str,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
history: Optional[HistoryType] = None,
|
|
||||||
template: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Build prompt using jinja2 template for query and history.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (str): query string.
|
|
||||||
system_prompt (Optional[str]): system prompt string.
|
|
||||||
history (Optional[HistoryType]): history list of query and response.
|
|
||||||
template (Optional[str]): jinja2 template string. If None, uses default chatml template.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: prompt string formatted according to the template.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Use default template
|
|
||||||
prompt = build_prompt(query="Hello", history=[...])
|
|
||||||
|
|
||||||
# Use custom template
|
|
||||||
custom_template = '''
|
|
||||||
{%- for msg in messages -%}
|
|
||||||
{{ msg['role'] }}: {{ msg['content'] }}
|
|
||||||
{%- endfor -%}
|
|
||||||
'''
|
|
||||||
prompt = build_prompt(query="Hello", template=custom_template)
|
|
||||||
"""
|
|
||||||
# Convert history to message format
|
|
||||||
messages: List[MessageType] = []
|
|
||||||
if history:
|
|
||||||
for user_msg, assistant_msg in history:
|
|
||||||
messages.append({"role": "user", "content": user_msg})
|
|
||||||
messages.append({"role": "assistant", "content": assistant_msg})
|
|
||||||
messages.append({"role": "user", "content": query})
|
|
||||||
|
|
||||||
# Use provided template or default chatml template
|
|
||||||
template_str = template if template is not None else CHAT_TEMPLATES["chatml"]
|
|
||||||
|
|
||||||
# Render template
|
|
||||||
jinja_template = Template(template_str)
|
|
||||||
return jinja_template.render(
|
|
||||||
messages=messages,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
from astrai.tokenizer.tokenizer import (
|
||||||
|
BaseTokenizer,
|
||||||
|
BpeTokenizer,
|
||||||
|
BaseTrainer,
|
||||||
|
BpeTrainer,
|
||||||
|
)
|
||||||
|
from astrai.tokenizer.chat_template import (
|
||||||
|
HistoryType,
|
||||||
|
MessageType,
|
||||||
|
CHAT_TEMPLATES,
|
||||||
|
build_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseTokenizer",
|
||||||
|
"BpeTokenizer",
|
||||||
|
"BaseTrainer",
|
||||||
|
"BpeTrainer",
|
||||||
|
"HistoryType",
|
||||||
|
"MessageType",
|
||||||
|
"CHAT_TEMPLATES",
|
||||||
|
"build_prompt",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from jinja2 import Template
|
||||||
|
|
||||||
|
HistoryType = List[Tuple[str, str]]
|
||||||
|
MessageType = Dict[str, str]
|
||||||
|
|
||||||
|
# Predefined chat templates using jinja2
|
||||||
|
CHAT_TEMPLATES: Dict[str, str] = {
|
||||||
|
"chatml": """{%- if system_prompt -%}
|
||||||
|
<|im▁start|>system
|
||||||
|
{{ system_prompt }}<|im▁end|>
|
||||||
|
{%- endif -%}
|
||||||
|
{%- for message in messages -%}
|
||||||
|
<|im▁start|>{{ message['role'] }}
|
||||||
|
{{ message['content'] }}<|im▁end|>
|
||||||
|
{%- endfor -%}
|
||||||
|
<|im▁start|>assistant
|
||||||
|
""",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
query: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
history: Optional[HistoryType] = None,
|
||||||
|
template: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build prompt using jinja2 template for query and history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): query string.
|
||||||
|
system_prompt (Optional[str]): system prompt string.
|
||||||
|
history (Optional[HistoryType]): history list of query and response.
|
||||||
|
template (Optional[str]): jinja2 template string. If None, uses default chatml template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: prompt string formatted according to the template.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Use default template
|
||||||
|
prompt = build_prompt(query="Hello", history=[...])
|
||||||
|
|
||||||
|
# Use custom template
|
||||||
|
custom_template = '''
|
||||||
|
{%- for msg in messages -%}
|
||||||
|
{{ msg['role'] }}: {{ msg['content'] }}
|
||||||
|
{%- endfor -%}
|
||||||
|
'''
|
||||||
|
prompt = build_prompt(query="Hello", template=custom_template)
|
||||||
|
"""
|
||||||
|
# Convert history to message format
|
||||||
|
messages: List[MessageType] = []
|
||||||
|
if history:
|
||||||
|
for user_msg, assistant_msg in history:
|
||||||
|
messages.append({"role": "user", "content": user_msg})
|
||||||
|
messages.append({"role": "assistant", "content": assistant_msg})
|
||||||
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
|
# Use provided template or default chatml template
|
||||||
|
template_str = template if template is not None else CHAT_TEMPLATES["chatml"]
|
||||||
|
|
||||||
|
# Render template
|
||||||
|
jinja_template = Template(template_str)
|
||||||
|
return jinja_template.render(
|
||||||
|
messages=messages,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from astrai.data.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.parallel import only_on_rank
|
from astrai.parallel import only_on_rank
|
||||||
from astrai.trainer.metric_util import (
|
from astrai.trainer.metric_util import (
|
||||||
ctx_get_grad_max,
|
ctx_get_grad_max,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.data import ResumableDistributedSampler
|
from astrai.data import ResumableDistributedSampler
|
||||||
from astrai.data.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
from astrai.data.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.parallel.setup import spawn_parallel_fn
|
from astrai.parallel.setup import spawn_parallel_fn
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from tokenizers import pre_tokenizers
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.data.tokenizer import BpeTokenizer, BpeTrainer
|
from astrai.tokenizer import BpeTokenizer, BpeTrainer
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch.distributed as dist
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from astrai.data.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.parallel.setup import get_rank, spawn_parallel_fn
|
from astrai.parallel.setup import get_rank, spawn_parallel_fn
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.data.dataset import DatasetFactory
|
from astrai.data.dataset import DatasetFactory
|
||||||
from astrai.data.serialization import save_h5
|
from astrai.serialization import save_h5
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.data.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.schedule import SchedulerFactory
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue