feat(data, inference): 使用 ChatML 格式

This commit is contained in:
ViperEkura 2025-10-29 12:02:43 +08:00
parent 38b2725cd1
commit d94fc5a87a
8 changed files with 51 additions and 114 deletions

View File

@ -8,8 +8,6 @@ from khaosz.data.data_util import (
ResumeableRandomSampler,
DatasetLoader,
load_pkl_files,
build_attention_mask,
build_loss_mask
)
from khaosz.data.tokenizer import BpeTokenizer
@ -24,7 +22,5 @@ __all__ = [
"ResumeableRandomSampler",
"DatasetLoader",
"load_pkl_files",
"build_attention_mask",
"build_loss_mask",
"BpeTokenizer"
]

View File

@ -4,7 +4,7 @@ import pickle as pkl
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset, Sampler
from typing import Callable, List, Dict, Literal, Union
from typing import Callable, List, Dict, Literal, Optional, Union
MutiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor]
@ -25,36 +25,6 @@ def load_pkl_files(paths: List[str]):
return segments, total_samples
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
seq_len = input_ids.size(0)
turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
# fix the causal attention mask (iq >= ik condition)
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
attention_mask = torch.tril(seq_mask)
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
# fix the eos_token_id bug (change target_ids to input_ids)
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]):
@ -111,11 +81,12 @@ class MutiSegmentFetcher:
class BaseDataset(Dataset, ABC):
def __init__(self, chunk_size: int):
def __init__(self, chunk_size: int, step_size: int):
super().__init__()
self.segments: MutiSeg = {}
self.chunk_size = chunk_size
self.total_samples = 0
self.step_size = step_size
self.total_samples = None
def save(self, save_path: str):
keys = list(self.segments.keys())
@ -140,16 +111,15 @@ class BaseDataset(Dataset, ABC):
raise NotImplementedError
def __len__(self) -> int:
assert self.total_samples // self.chunk_size > 0
return self.total_samples // self.chunk_size
assert self.total_samples is not None
if self.total_samples < self.chunk_size:
return 0
return (self.total_samples - self.chunk_size) // self.step_size + 1
class SeqDataset(BaseDataset):
def __init__(
self,
chunk_size,
):
super().__init__(chunk_size)
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
@ -167,41 +137,27 @@ class SeqDataset(BaseDataset):
class SftDataset(BaseDataset):
def __init__(
self,
chunk_size,
bos_token_id,
eos_token_id,
user_token_id,
multi_turn=False,
):
super().__init__(chunk_size)
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
self.fetcher = MutiSegmentFetcher(self.segments)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.multi_turn = multi_turn
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
# fix the eos_token_id bug (change target_ids to input_ids)
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
class DpoDataset(BaseDataset):
def __init__(self, chunk_size: int):
super().__init__(chunk_size)
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -220,8 +176,8 @@ class DpoDataset(BaseDataset):
class PpoDataset(BaseDataset):
def __init__(self, chunk_size: int):
super().__init__(chunk_size)
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -245,19 +201,16 @@ class DatasetLoader:
train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]],
max_len: int,
step_size: Optional[int] = None,
**kwargs
) -> BaseDataset:
if step_size is None:
step_size = max_len
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
"seq": lambda max_len: SeqDataset(max_len),
"sft": lambda max_len: SftDataset(
max_len,
bos_token_id=kwargs.get("bos_token_id"),
eos_token_id=kwargs.get("eos_token_id"),
user_token_id=kwargs.get("user_token_id"),
multi_turn=kwargs.get("multi_turn")
),
"dpo": lambda max_len: DpoDataset(max_len),
"seq": lambda max_len: SeqDataset(max_len, step_size),
"sft": lambda max_len: SftDataset(max_len, step_size),
"dpo": lambda max_len: DpoDataset(max_len, step_size),
}
dataset = dataset_router[train_type](max_len)
dataset.load(load_path)

View File

@ -8,7 +8,7 @@ from typing import List, Union
class BpeTokenizer:
def __init__(self, path=None):
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
self._special_tokens = ["<|user|>", "<|system|>"]
self._special_tokens = ["<|im_start|>", "<|im_end|>"]
model = BPE()
tokenizer = Tokenizer(model)
tokenizer.normalizer = normalizers.Sequence([
@ -93,9 +93,7 @@ class BpeTokenizer:
@property
def stop_ids(self) -> List[int]:
stop_ids = []
for token in self._control_tokens:
stop_ids.append(self._tokenizer.token_to_id(token))
stop_ids = self._control_tokens + self._special_tokens
return stop_ids
@property
@ -108,12 +106,4 @@ class BpeTokenizer:
@property
def pad_id(self) -> int:
return self._tokenizer.token_to_id("<pad>")
@property
def user_id(self) -> int:
return self._tokenizer.token_to_id("<|user|>")
@property
def system_id(self) -> int:
return self._tokenizer.token_to_id("<|system|>")
return self._tokenizer.token_to_id("<pad>")

View File

@ -5,30 +5,34 @@ from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheMa
from khaosz.config.param_config import ModelParameter
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
def build_prompt(
query: str,
init_prompt: Optional[str] = None,
history: Optional[List[Tuple[str, str]]] = None
) -> str:
"""
Build prompt for query and history
Build prompt in ChatML format for query and history
Args:
query(str): query string
history(Optional[List[Tuple[str, str]]]): history list of query and response
Returns:
str: prompt string
str: prompt string in ChatML format
"""
prompt_parts = []
prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
if history is None:
history = []
# (convert tuple format to ChatML)
if history:
for user_msg, assistant_msg in history:
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
for his_query, his_response in history:
prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>")
if query is not None:
prompt_parts.append(f"<|user|> {query} <|system|> <bos>")
prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
prompt += "<|im_start|>assistant\n"
return "\n".join(prompt_parts)
return prompt
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
"""

View File

@ -68,11 +68,10 @@ class SftStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
ignore_index = -100
logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"]
logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(

View File

@ -10,7 +10,6 @@ import matplotlib
from torch.utils.data import Dataset
from khaosz.config.model_config import TransformerConfig
from khaosz.data.data_util import build_attention_mask, build_loss_mask
from khaosz.data.tokenizer import BpeTokenizer
from khaosz.model.transformer import Transformer
@ -46,14 +45,12 @@ class MultiTurnDataset(Dataset):
def __getitem__(self, idx):
input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
loss_mask = build_loss_mask(input_ids, 0, 1)
attn_mask = build_attention_mask(input_ids, 2, True)
loss_mask = torch.randint(0, 1, (self.max_length,))
return {
"input_ids": input_ids,
"target_ids": target_ids,
"loss_mask": loss_mask,
"attn_mask": attn_mask,
}

View File

@ -31,7 +31,6 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
base_test_env["device"],
bos_token_id=2,
eos_token_id=3,
user_token_id=1,
multi_turn=True
)

View File

@ -61,7 +61,6 @@ def train(
"bos_token_id": parameter.tokenizer.bos_id,
"eos_token_id": parameter.tokenizer.eos_id,
"pad_token_id": parameter.tokenizer.pad_id,
"user_token_id":parameter.tokenizer.user_id,
}
strategy = StrategyFactory.load(