feat(data, inference): 使用chatML格式
This commit is contained in:
parent
38b2725cd1
commit
d94fc5a87a
|
|
@ -8,8 +8,6 @@ from khaosz.data.data_util import (
|
|||
ResumeableRandomSampler,
|
||||
DatasetLoader,
|
||||
load_pkl_files,
|
||||
build_attention_mask,
|
||||
build_loss_mask
|
||||
)
|
||||
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
|
|
@ -24,7 +22,5 @@ __all__ = [
|
|||
"ResumeableRandomSampler",
|
||||
"DatasetLoader",
|
||||
"load_pkl_files",
|
||||
"build_attention_mask",
|
||||
"build_loss_mask",
|
||||
"BpeTokenizer"
|
||||
]
|
||||
|
|
@ -4,7 +4,7 @@ import pickle as pkl
|
|||
from abc import ABC, abstractmethod
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from typing import Callable, List, Dict, Literal, Union
|
||||
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||
|
||||
MutiSeg = Dict[str, List[Tensor]]
|
||||
Seg = Dict[str, Tensor]
|
||||
|
|
@ -25,36 +25,6 @@ def load_pkl_files(paths: List[str]):
|
|||
|
||||
return segments, total_samples
|
||||
|
||||
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
|
||||
seq_len = input_ids.size(0)
|
||||
turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
|
||||
|
||||
iq = turn_id.view(seq_len, 1)
|
||||
ik = turn_id.view(1, seq_len)
|
||||
|
||||
# fix the causual attention mask(iq >= ik condition)
|
||||
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
|
||||
attention_mask = torch.tril(seq_mask)
|
||||
|
||||
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
|
||||
return attention_mask.unsqueeze(0)
|
||||
|
||||
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||
|
||||
is_bos_token = input_ids.eq(bos_token_id)
|
||||
is_eos_token = input_ids.eq(eos_token_id)
|
||||
|
||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||
token_markers[is_bos_token] = 1
|
||||
token_markers[is_eos_token] = -1
|
||||
|
||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
||||
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
||||
loss_mask = cumulative_markers - min_cumulative
|
||||
|
||||
return loss_mask.to(dtype=torch.bool)
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
|
|
@ -111,11 +81,12 @@ class MutiSegmentFetcher:
|
|||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
def __init__(self, chunk_size: int):
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__()
|
||||
self.segments: MutiSeg = {}
|
||||
self.chunk_size = chunk_size
|
||||
self.total_samples = 0
|
||||
self.step_size = step_size
|
||||
self.total_samples = None
|
||||
|
||||
def save(self, save_path: str):
|
||||
keys = list(self.segments.keys())
|
||||
|
|
@ -140,16 +111,15 @@ class BaseDataset(Dataset, ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
assert self.total_samples // self.chunk_size > 0
|
||||
return self.total_samples // self.chunk_size
|
||||
assert self.total_samples is not None
|
||||
if self.total_samples < self.chunk_size:
|
||||
return 0
|
||||
return (self.total_samples - self.chunk_size) // self.step_size + 1
|
||||
|
||||
|
||||
class SeqDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size,
|
||||
):
|
||||
super().__init__(chunk_size)
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
|
|
@ -167,41 +137,27 @@ class SeqDataset(BaseDataset):
|
|||
|
||||
|
||||
class SftDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size,
|
||||
bos_token_id,
|
||||
eos_token_id,
|
||||
user_token_id,
|
||||
multi_turn=False,
|
||||
):
|
||||
super().__init__(chunk_size)
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.user_token_id = user_token_id
|
||||
self.multi_turn = multi_turn
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
|
||||
|
||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
||||
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
|
||||
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||
|
||||
|
||||
class DpoDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int):
|
||||
super().__init__(chunk_size)
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
|
|
@ -220,8 +176,8 @@ class DpoDataset(BaseDataset):
|
|||
|
||||
|
||||
class PpoDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int):
|
||||
super().__init__(chunk_size)
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
|
|
@ -245,19 +201,16 @@ class DatasetLoader:
|
|||
train_type: Literal["seq", "sft", "dpo"],
|
||||
load_path: Union[str, List[str]],
|
||||
max_len: int,
|
||||
step_size: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> BaseDataset:
|
||||
if step_size is None:
|
||||
step_size = max_len
|
||||
|
||||
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
||||
"seq": lambda max_len: SeqDataset(max_len),
|
||||
"sft": lambda max_len: SftDataset(
|
||||
max_len,
|
||||
bos_token_id=kwargs.get("bos_token_id"),
|
||||
eos_token_id=kwargs.get("eos_token_id"),
|
||||
user_token_id=kwargs.get("user_token_id"),
|
||||
multi_turn=kwargs.get("multi_turn")
|
||||
),
|
||||
"dpo": lambda max_len: DpoDataset(max_len),
|
||||
"seq": lambda max_len: SeqDataset(max_len, step_size),
|
||||
"sft": lambda max_len: SftDataset(max_len, step_size),
|
||||
"dpo": lambda max_len: DpoDataset(max_len, step_size),
|
||||
}
|
||||
dataset = dataset_router[train_type](max_len)
|
||||
dataset.load(load_path)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import List, Union
|
|||
class BpeTokenizer:
|
||||
def __init__(self, path=None):
|
||||
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
|
||||
self._special_tokens = ["<|user|>", "<|system|>"]
|
||||
self._special_tokens = ["<|im_start|>", "<|im_end|>"]
|
||||
model = BPE()
|
||||
tokenizer = Tokenizer(model)
|
||||
tokenizer.normalizer = normalizers.Sequence([
|
||||
|
|
@ -93,9 +93,7 @@ class BpeTokenizer:
|
|||
|
||||
@property
|
||||
def stop_ids(self) -> List[int]:
|
||||
stop_ids = []
|
||||
for token in self._control_tokens:
|
||||
stop_ids.append(self._tokenizer.token_to_id(token))
|
||||
stop_ids = self._control_tokens + self._special_tokens
|
||||
return stop_ids
|
||||
|
||||
@property
|
||||
|
|
@ -109,11 +107,3 @@ class BpeTokenizer:
|
|||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<pad>")
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<|user|>")
|
||||
|
||||
@property
|
||||
def system_id(self) -> int:
|
||||
return self._tokenizer.token_to_id("<|system|>")
|
||||
|
|
@ -5,30 +5,34 @@ from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheMa
|
|||
from khaosz.config.param_config import ModelParameter
|
||||
|
||||
|
||||
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
|
||||
def build_prompt(
|
||||
query: str,
|
||||
init_prompt: Optional[str] = None,
|
||||
history: Optional[List[Tuple[str, str]]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build prompt for query and history
|
||||
Build prompt in ChatML format for query and history
|
||||
|
||||
Args:
|
||||
query(str): query string
|
||||
history(Optional[List[Tuple[str, str]]]): history list of query and response
|
||||
|
||||
Returns:
|
||||
str: prompt string
|
||||
str: prompt string in ChatML format
|
||||
|
||||
"""
|
||||
prompt_parts = []
|
||||
prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
|
||||
|
||||
if history is None:
|
||||
history = []
|
||||
# (convert tuple format to ChatML)
|
||||
if history:
|
||||
for user_msg, assistant_msg in history:
|
||||
prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
|
||||
prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
|
||||
|
||||
for his_query, his_response in history:
|
||||
prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>")
|
||||
prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
|
||||
prompt += "<|im_start|>assistant\n"
|
||||
|
||||
if query is not None:
|
||||
prompt_parts.append(f"<|user|> {query} <|system|> <bos>")
|
||||
|
||||
return "\n".join(prompt_parts)
|
||||
return prompt
|
||||
|
||||
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -68,11 +68,10 @@ class SftStrategy(BaseStrategy):
|
|||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
|
||||
input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
|
||||
|
||||
ignore_index = -100
|
||||
logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"]
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import matplotlib
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from khaosz.config.model_config import TransformerConfig
|
||||
from khaosz.data.data_util import build_attention_mask, build_loss_mask
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
from khaosz.model.transformer import Transformer
|
||||
|
||||
|
|
@ -46,14 +45,12 @@ class MultiTurnDataset(Dataset):
|
|||
def __getitem__(self, idx):
|
||||
input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
||||
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
||||
loss_mask = build_loss_mask(input_ids, 0, 1)
|
||||
attn_mask = build_attention_mask(input_ids, 2, True)
|
||||
loss_mask = torch.randint(0, 1, (self.max_length,))
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"target_ids": target_ids,
|
||||
"loss_mask": loss_mask,
|
||||
"attn_mask": attn_mask,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
|
|||
base_test_env["device"],
|
||||
bos_token_id=2,
|
||||
eos_token_id=3,
|
||||
user_token_id=1,
|
||||
multi_turn=True
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue