feat(data, inference): 使用ChatML格式

This commit is contained in:
ViperEkura 2025-10-29 12:02:43 +08:00
parent 38b2725cd1
commit d94fc5a87a
8 changed files with 51 additions and 114 deletions

View File

@ -8,8 +8,6 @@ from khaosz.data.data_util import (
ResumeableRandomSampler, ResumeableRandomSampler,
DatasetLoader, DatasetLoader,
load_pkl_files, load_pkl_files,
build_attention_mask,
build_loss_mask
) )
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
@ -24,7 +22,5 @@ __all__ = [
"ResumeableRandomSampler", "ResumeableRandomSampler",
"DatasetLoader", "DatasetLoader",
"load_pkl_files", "load_pkl_files",
"build_attention_mask",
"build_loss_mask",
"BpeTokenizer" "BpeTokenizer"
] ]

View File

@ -4,7 +4,7 @@ import pickle as pkl
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset, Sampler from torch.utils.data import Dataset, Sampler
from typing import Callable, List, Dict, Literal, Union from typing import Callable, List, Dict, Literal, Optional, Union
MutiSeg = Dict[str, List[Tensor]] MutiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor] Seg = Dict[str, Tensor]
@ -25,36 +25,6 @@ def load_pkl_files(paths: List[str]):
return segments, total_samples return segments, total_samples
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
seq_len = input_ids.size(0)
turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
# fix the causal attention mask (iq >= ik condition)
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
attention_mask = torch.tril(seq_mask)
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
# fix the eos_token_id bug(change target_ids to input_ids)
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
class BaseSegmentFetcher: class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]): def __init__(self, segments: List[Tensor]):
@ -111,11 +81,12 @@ class MutiSegmentFetcher:
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
def __init__(self, chunk_size: int): def __init__(self, chunk_size: int, step_size: int):
super().__init__() super().__init__()
self.segments: MutiSeg = {} self.segments: MutiSeg = {}
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.total_samples = 0 self.step_size = step_size
self.total_samples = None
def save(self, save_path: str): def save(self, save_path: str):
keys = list(self.segments.keys()) keys = list(self.segments.keys())
@ -140,16 +111,15 @@ class BaseDataset(Dataset, ABC):
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int: def __len__(self) -> int:
assert self.total_samples // self.chunk_size > 0 assert self.total_samples is not None
return self.total_samples // self.chunk_size if self.total_samples < self.chunk_size:
return 0
return (self.total_samples - self.chunk_size) // self.step_size + 1
class SeqDataset(BaseDataset): class SeqDataset(BaseDataset):
def __init__( def __init__(self, chunk_size: int, step_size: int):
self, super().__init__(chunk_size, step_size)
chunk_size,
):
super().__init__(chunk_size)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
@ -167,41 +137,27 @@ class SeqDataset(BaseDataset):
class SftDataset(BaseDataset): class SftDataset(BaseDataset):
def __init__( def __init__(self, chunk_size: int, step_size: int):
self, super().__init__(chunk_size, step_size)
chunk_size,
bos_token_id,
eos_token_id,
user_token_id,
multi_turn=False,
):
super().__init__(chunk_size)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
self.user_token_id = user_token_id return self.fetcher.key_fetch(begin_idx, end_idx, key)
self.multi_turn = multi_turn
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index): def __getitem__(self, index):
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size end_idx = begin_idx + self.chunk_size
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
# fix the eos_token_id bug(change target_ids to input_ids) return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
class DpoDataset(BaseDataset): class DpoDataset(BaseDataset):
def __init__(self, chunk_size: int): def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size) super().__init__(chunk_size, step_size)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -220,8 +176,8 @@ class DpoDataset(BaseDataset):
class PpoDataset(BaseDataset): class PpoDataset(BaseDataset):
def __init__(self, chunk_size: int): def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size) super().__init__(chunk_size, step_size)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -245,19 +201,16 @@ class DatasetLoader:
train_type: Literal["seq", "sft", "dpo"], train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]], load_path: Union[str, List[str]],
max_len: int, max_len: int,
step_size: Optional[int] = None,
**kwargs **kwargs
) -> BaseDataset: ) -> BaseDataset:
if step_size is None:
step_size = max_len
dataset_router: Dict[str, Callable[[int], BaseDataset]] = { dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
"seq": lambda max_len: SeqDataset(max_len), "seq": lambda max_len: SeqDataset(max_len, step_size),
"sft": lambda max_len: SftDataset( "sft": lambda max_len: SftDataset(max_len, step_size),
max_len, "dpo": lambda max_len: DpoDataset(max_len, step_size),
bos_token_id=kwargs.get("bos_token_id"),
eos_token_id=kwargs.get("eos_token_id"),
user_token_id=kwargs.get("user_token_id"),
multi_turn=kwargs.get("multi_turn")
),
"dpo": lambda max_len: DpoDataset(max_len),
} }
dataset = dataset_router[train_type](max_len) dataset = dataset_router[train_type](max_len)
dataset.load(load_path) dataset.load(load_path)

View File

@ -8,7 +8,7 @@ from typing import List, Union
class BpeTokenizer: class BpeTokenizer:
def __init__(self, path=None): def __init__(self, path=None):
self._control_tokens = ["<bos>", "<eos>", "<pad>"] self._control_tokens = ["<bos>", "<eos>", "<pad>"]
self._special_tokens = ["<|user|>", "<|system|>"] self._special_tokens = ["<|im_start|>", "<|im_end|>"]
model = BPE() model = BPE()
tokenizer = Tokenizer(model) tokenizer = Tokenizer(model)
tokenizer.normalizer = normalizers.Sequence([ tokenizer.normalizer = normalizers.Sequence([
@ -93,9 +93,7 @@ class BpeTokenizer:
@property @property
def stop_ids(self) -> List[int]: def stop_ids(self) -> List[int]:
stop_ids = [] stop_ids = self._control_tokens + self._special_tokens
for token in self._control_tokens:
stop_ids.append(self._tokenizer.token_to_id(token))
return stop_ids return stop_ids
@property @property
@ -108,12 +106,4 @@ class BpeTokenizer:
@property @property
def pad_id(self) -> int: def pad_id(self) -> int:
return self._tokenizer.token_to_id("<pad>") return self._tokenizer.token_to_id("<pad>")
@property
def user_id(self) -> int:
return self._tokenizer.token_to_id("<|user|>")
@property
def system_id(self) -> int:
return self._tokenizer.token_to_id("<|system|>")

View File

@ -5,30 +5,34 @@ from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheMa
from khaosz.config.param_config import ModelParameter from khaosz.config.param_config import ModelParameter
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str: def build_prompt(
query: str,
init_prompt: Optional[str] = None,
history: Optional[List[Tuple[str, str]]] = None
) -> str:
""" """
Build prompt for query and history Build prompt in ChatML format for query and history
Args: Args:
query(str): query string query(str): query string
history(Optional[List[Tuple[str, str]]]): history list of query and response history(Optional[List[Tuple[str, str]]]): history list of query and response
Returns: Returns:
str: prompt string str: prompt string in ChatML format
""" """
prompt_parts = [] prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
if history is None: # (convert tuple format to ChatML)
history = [] if history:
for user_msg, assistant_msg in history:
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
for his_query, his_response in history: prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>") prompt += "<|im_start|>assistant\n"
if query is not None:
prompt_parts.append(f"<|user|> {query} <|system|> <bos>")
return "\n".join(prompt_parts) return prompt
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]: def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
""" """

View File

@ -68,11 +68,10 @@ class SftStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"] input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
ignore_index = -100 ignore_index = -100
logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"] logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy( loss = F.cross_entropy(

View File

@ -10,7 +10,6 @@ import matplotlib
from torch.utils.data import Dataset from torch.utils.data import Dataset
from khaosz.config.model_config import TransformerConfig from khaosz.config.model_config import TransformerConfig
from khaosz.data.data_util import build_attention_mask, build_loss_mask
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
@ -46,14 +45,12 @@ class MultiTurnDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
input_ids = torch.randint(0, self.vocab_size, (self.max_length,)) input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
target_ids = torch.randint(0, self.vocab_size, (self.max_length,)) target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
loss_mask = build_loss_mask(input_ids, 0, 1) loss_mask = torch.randint(0, 1, (self.max_length,))
attn_mask = build_attention_mask(input_ids, 2, True)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"target_ids": target_ids, "target_ids": target_ids,
"loss_mask": loss_mask, "loss_mask": loss_mask,
"attn_mask": attn_mask,
} }

View File

@ -31,7 +31,6 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
base_test_env["device"], base_test_env["device"],
bos_token_id=2, bos_token_id=2,
eos_token_id=3, eos_token_id=3,
user_token_id=1,
multi_turn=True multi_turn=True
) )

View File

@ -61,7 +61,6 @@ def train(
"bos_token_id": parameter.tokenizer.bos_id, "bos_token_id": parameter.tokenizer.bos_id,
"eos_token_id": parameter.tokenizer.eos_id, "eos_token_id": parameter.tokenizer.eos_id,
"pad_token_id": parameter.tokenizer.pad_id, "pad_token_id": parameter.tokenizer.pad_id,
"user_token_id":parameter.tokenizer.user_id,
} }
strategy = StrategyFactory.load( strategy = StrategyFactory.load(