feat(data, inference): use ChatML format

parent 38b2725cd1
commit d94fc5a87a

@@ -8,8 +8,6 @@ from khaosz.data.data_util import (
     ResumeableRandomSampler,
     DatasetLoader,
     load_pkl_files,
-    build_attention_mask,
-    build_loss_mask
 )
 
 from khaosz.data.tokenizer import BpeTokenizer

@@ -24,7 +22,5 @@ __all__ = [
     "ResumeableRandomSampler",
     "DatasetLoader",
     "load_pkl_files",
-    "build_attention_mask",
-    "build_loss_mask",
     "BpeTokenizer"
 ]

@@ -4,7 +4,7 @@ import pickle as pkl
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset, Sampler
-from typing import Callable, List, Dict, Literal, Union
+from typing import Callable, List, Dict, Literal, Optional, Union
 
 MutiSeg = Dict[str, List[Tensor]]
 Seg = Dict[str, Tensor]

@@ -25,36 +25,6 @@ def load_pkl_files(paths: List[str]):
 
     return segments, total_samples
 
-def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
-    seq_len = input_ids.size(0)
-    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
-
-    iq = turn_id.view(seq_len, 1)
-    ik = turn_id.view(1, seq_len)
-
-    # fix the causual attention mask(iq >= ik condition)
-    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
-    attention_mask = torch.tril(seq_mask)
-
-    # fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
-    return attention_mask.unsqueeze(0)
-
-def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
-    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
-
-    is_bos_token = input_ids.eq(bos_token_id)
-    is_eos_token = input_ids.eq(eos_token_id)
-
-    # fix the eos_token_id bug(change target_ids to input_ids)
-    token_markers[is_bos_token] = 1
-    token_markers[is_eos_token] = -1
-
-    cumulative_markers = torch.cumsum(token_markers, dim=-1)
-    min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
-    loss_mask = cumulative_markers - min_cumulative
-
-    return loss_mask.to(dtype=torch.bool)
-
 
 class BaseSegmentFetcher:
     def __init__(self, segments: List[Tensor]):

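Note: for context on the deleted helper, here is a small worked example of the cumulative-marker trick build_loss_mask used (token ids are made up; <bos>=1, <eos>=2). The SFT dataset now reads a precomputed "loss_mask" segment from the pickled data instead (see the SftDataset hunk below).

    import torch

    input_ids = torch.tensor([5, 6, 1, 7, 8, 2, 9])   # hypothetical ids, <bos>=1, <eos>=2

    markers = torch.zeros_like(input_ids, dtype=torch.int8)
    markers[input_ids.eq(1)] = 1     # +1 at <bos>
    markers[input_ids.eq(2)] = -1    # -1 at <eos>

    cum = torch.cumsum(markers, dim=-1)
    loss_mask = (cum - cum.min(dim=-1, keepdim=True).values).to(torch.bool)
    # tensor([False, False,  True,  True,  True, False, False])
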
@@ -111,11 +81,12 @@ class MutiSegmentFetcher:
 
 
 class BaseDataset(Dataset, ABC):
-    def __init__(self, chunk_size: int):
+    def __init__(self, chunk_size: int, step_size: int):
         super().__init__()
         self.segments: MutiSeg = {}
         self.chunk_size = chunk_size
-        self.total_samples = 0
+        self.step_size = step_size
+        self.total_samples = None
 
     def save(self, save_path: str):
         keys = list(self.segments.keys())

@@ -140,16 +111,15 @@ class BaseDataset(Dataset, ABC):
         raise NotImplementedError
 
     def __len__(self) -> int:
-        assert self.total_samples // self.chunk_size > 0
-        return self.total_samples // self.chunk_size
+        assert self.total_samples is not None
+        if self.total_samples < self.chunk_size:
+            return 0
+        return (self.total_samples - self.chunk_size) // self.step_size + 1
 
 
 class SeqDataset(BaseDataset):
-    def __init__(
-        self,
-        chunk_size,
-    ):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:

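Note: __len__ now counts sliding windows instead of disjoint chunks. A quick sanity check of the formula (numbers are illustrative):

    total_samples, chunk_size, step_size = 10, 4, 2

    if total_samples < chunk_size:
        n_windows = 0
    else:
        n_windows = (total_samples - chunk_size) // step_size + 1

    print(n_windows)  # 4 -> windows start at offsets 0, 2, 4, 6

With step_size equal to chunk_size (the DatasetLoader default further down), this yields the same count as the old total_samples // chunk_size whenever total_samples >= chunk_size.
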
@@ -167,41 +137,27 @@ class SeqDataset(BaseDataset):
 
 
 class SftDataset(BaseDataset):
-    def __init__(
-        self,
-        chunk_size,
-        bos_token_id,
-        eos_token_id,
-        user_token_id,
-        multi_turn=False,
-    ):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.user_token_id = user_token_id
-        self.multi_turn = multi_turn
 
-    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
-        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
+    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+        return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index):
         begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
         end_idx = begin_idx + self.chunk_size
 
-        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
-        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
+        x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
+        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
+        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
 
-        # fix the eos_token_id bug(change target_ids to input_ids)
-        loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
-        attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
-
-        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
+        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
 
 
 class DpoDataset(BaseDataset):
-    def __init__(self, chunk_size: int):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:

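Note: __getitem__ now returns the precomputed loss mask alongside the shifted targets and no longer builds an attention mask on the fly. A minimal sketch of the alignment, with stand-in tensors in place of the pickled "sequence" and "loss_mask" segments:

    import torch

    chunk_size = 8
    sequence = torch.arange(32)                   # stand-in for the "sequence" segment
    mask = torch.zeros(32, dtype=torch.bool)      # stand-in for the "loss_mask" segment
    mask[10:20] = True

    begin_idx = 4
    end_idx = begin_idx + chunk_size

    x = sequence[begin_idx:end_idx].long()               # input_ids
    y = sequence[begin_idx + 1:end_idx + 1].long()        # target_ids, shifted by one
    loss_mask = mask[begin_idx + 1:end_idx + 1]            # aligned with target_ids

    item = {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
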
@@ -220,8 +176,8 @@ class DpoDataset(BaseDataset):
 
 
 class PpoDataset(BaseDataset):
-    def __init__(self, chunk_size: int):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:

@@ -245,19 +201,16 @@ class DatasetLoader:
         train_type: Literal["seq", "sft", "dpo"],
         load_path: Union[str, List[str]],
         max_len: int,
+        step_size: Optional[int] = None,
         **kwargs
     ) -> BaseDataset:
+        if step_size is None:
+            step_size = max_len
 
         dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
-            "seq": lambda max_len: SeqDataset(max_len),
-            "sft": lambda max_len: SftDataset(
-                max_len,
-                bos_token_id=kwargs.get("bos_token_id"),
-                eos_token_id=kwargs.get("eos_token_id"),
-                user_token_id=kwargs.get("user_token_id"),
-                multi_turn=kwargs.get("multi_turn")
-            ),
-            "dpo": lambda max_len: DpoDataset(max_len),
+            "seq": lambda max_len: SeqDataset(max_len, step_size),
+            "sft": lambda max_len: SftDataset(max_len, step_size),
+            "dpo": lambda max_len: DpoDataset(max_len, step_size),
         }
         dataset = dataset_router[train_type](max_len)
         dataset.load(load_path)

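Note: step_size now flows through DatasetLoader into every dataset; when omitted it falls back to max_len, i.e. non-overlapping windows. A tiny sketch of the defaulting rule (the helper name is made up for illustration):

    from typing import Optional

    def resolve_step_size(max_len: int, step_size: Optional[int] = None) -> int:
        # None -> windows advance by a full chunk (no overlap)
        return max_len if step_size is None else step_size

    assert resolve_step_size(512) == 512        # default: disjoint windows
    assert resolve_step_size(512, 256) == 256   # explicit: 50% overlap between windows
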
@@ -8,7 +8,7 @@ from typing import List, Union
 class BpeTokenizer:
     def __init__(self, path=None):
         self._control_tokens = ["<bos>", "<eos>", "<pad>"]
-        self._special_tokens = ["<|user|>", "<|system|>"]
+        self._special_tokens = ["<|im_start|>", "<|im_end|>"]
         model = BPE()
         tokenizer = Tokenizer(model)
         tokenizer.normalizer = normalizers.Sequence([

@@ -93,9 +93,7 @@ class BpeTokenizer:
 
     @property
     def stop_ids(self) -> List[int]:
-        stop_ids = []
-        for token in self._control_tokens:
-            stop_ids.append(self._tokenizer.token_to_id(token))
+        stop_ids = self._control_tokens + self._special_tokens
         return stop_ids
 
     @property

@@ -108,12 +106,4 @@ class BpeTokenizer:
 
     @property
     def pad_id(self) -> int:
         return self._tokenizer.token_to_id("<pad>")
-
-    @property
-    def user_id(self) -> int:
-        return self._tokenizer.token_to_id("<|user|>")
-
-    @property
-    def system_id(self) -> int:
-        return self._tokenizer.token_to_id("<|system|>")

@@ -5,30 +5,34 @@ from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheMa
 from khaosz.config.param_config import ModelParameter
 
 
-def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
+def build_prompt(
+    query: str,
+    init_prompt: Optional[str] = None,
+    history: Optional[List[Tuple[str, str]]] = None
+) -> str:
     """
-    Build prompt for query and history
+    Build prompt in ChatML format for query and history
 
     Args:
         query(str): query string
         history(Optional[List[Tuple[str, str]]]): history list of query and response
 
     Returns:
-        str: prompt string
+        str: prompt string in ChatML format
 
     """
-    prompt_parts = []
+    prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
 
-    if history is None:
-        history = []
+    # (convert tuple format to ChatML)
+    if history:
+        for user_msg, assistant_msg in history:
+            prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+            prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
 
-    for his_query, his_response in history:
-        prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>")
+    prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
+    prompt += "<|im_start|>assistant\n"
 
-    if query is not None:
-        prompt_parts.append(f"<|user|> {query} <|system|> <bos>")
-
-    return "\n".join(prompt_parts)
+    return prompt
 
 def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
     """

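Note: roughly what the new build_prompt produces for a system prompt plus one prior turn (import omitted; the function is the one defined in the hunk above, output shown in the comments):

    prompt = build_prompt(
        query="And the weather tomorrow?",
        init_prompt="You are a helpful assistant.",
        history=[("Hi", "Hello! How can I help?")],
    )
    print(prompt)
    # <|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Hi<|im_end|>
    # <|im_start|>assistant
    # Hello! How can I help?<|im_end|>
    # <|im_start|>user
    # And the weather tomorrow?<|im_end|>
    # <|im_start|>assistant

Generation then continues from the open assistant block at the end of the prompt.
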
@@ -68,11 +68,10 @@ class SftStrategy(BaseStrategy):
 
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
         batch = move_to_device(batch, self.device)
-        input_ids, target_ids = batch["input_ids"], batch["target_ids"]
-        loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
+        input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
 
         ignore_index = -100
-        logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"]
+        logits = self.model(input_ids=input_ids)["logits"]
         target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
 
         loss = F.cross_entropy(

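Note: with the custom attention-mask plumbing gone, only the loss mask shapes the SFT loss. A minimal, self-contained sketch of the masking step, assuming the cross_entropy call (cut off in this hunk) flattens logits and targets and passes the same ignore_index:

    import torch
    import torch.nn.functional as F

    ignore_index = -100
    vocab_size = 11

    logits = torch.randn(2, 5, vocab_size)              # (batch, seq, vocab)
    target_ids = torch.randint(0, vocab_size, (2, 5))   # next-token targets
    loss_mask = torch.tensor([[0, 0, 1, 1, 1],
                              [0, 1, 1, 1, 0]], dtype=torch.bool)

    # positions where loss_mask == 0 are excluded from the loss
    target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
    loss = F.cross_entropy(
        logits.view(-1, vocab_size),
        target_ids.view(-1),
        ignore_index=ignore_index,
    )
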
@@ -10,7 +10,6 @@ import matplotlib
 from torch.utils.data import Dataset
 
 from khaosz.config.model_config import TransformerConfig
-from khaosz.data.data_util import build_attention_mask, build_loss_mask
 from khaosz.data.tokenizer import BpeTokenizer
 from khaosz.model.transformer import Transformer
 

@@ -46,14 +45,12 @@ class MultiTurnDataset(Dataset):
     def __getitem__(self, idx):
         input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
         target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
-        loss_mask = build_loss_mask(input_ids, 0, 1)
-        attn_mask = build_attention_mask(input_ids, 2, True)
+        loss_mask = torch.randint(0, 1, (self.max_length,))
 
         return {
             "input_ids": input_ids,
             "target_ids": target_ids,
             "loss_mask": loss_mask,
-            "attn_mask": attn_mask,
         }
 
 

@@ -31,7 +31,6 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
         base_test_env["device"],
         bos_token_id=2,
         eos_token_id=3,
-        user_token_id=1,
         multi_turn=True
     )
 

train.py
@@ -61,7 +61,6 @@ def train(
         "bos_token_id": parameter.tokenizer.bos_id,
         "eos_token_id": parameter.tokenizer.eos_id,
         "pad_token_id": parameter.tokenizer.pad_id,
-        "user_token_id":parameter.tokenizer.user_id,
     }
 
     strategy = StrategyFactory.load(