diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py
index 10b3f44..9a3c5ae 100644
--- a/khaosz/data/__init__.py
+++ b/khaosz/data/__init__.py
@@ -8,8 +8,6 @@ from khaosz.data.data_util import (
     ResumeableRandomSampler,
     DatasetLoader,
     load_pkl_files,
-    build_attention_mask,
-    build_loss_mask
 )
 
 from khaosz.data.tokenizer import BpeTokenizer
@@ -24,7 +22,5 @@ __all__ = [
     "ResumeableRandomSampler",
     "DatasetLoader",
     "load_pkl_files",
-    "build_attention_mask",
-    "build_loss_mask",
     "BpeTokenizer"
 ]
\ No newline at end of file
diff --git a/khaosz/data/data_util.py b/khaosz/data/data_util.py
index 1a1c317..8e022c3 100644
--- a/khaosz/data/data_util.py
+++ b/khaosz/data/data_util.py
@@ -4,7 +4,7 @@ import pickle as pkl
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset, Sampler
-from typing import Callable, List, Dict, Literal, Union
+from typing import Callable, List, Dict, Literal, Optional, Union
 
 MutiSeg = Dict[str, List[Tensor]]
 Seg = Dict[str, Tensor]
@@ -25,36 +25,6 @@ def load_pkl_files(paths: List[str]):
     return segments, total_samples
 
 
-def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
-    seq_len = input_ids.size(0)
-    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
-
-    iq = turn_id.view(seq_len, 1)
-    ik = turn_id.view(1, seq_len)
-
-    # fix the causual attention mask(iq >= ik condition)
-    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
-    attention_mask = torch.tril(seq_mask)
-
-    # fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
-    return attention_mask.unsqueeze(0)
-
-def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
-    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
-
-    is_bos_token = input_ids.eq(bos_token_id)
-    is_eos_token = input_ids.eq(eos_token_id)
-
-    # fix the eos_token_id bug(change target_ids to input_ids)
-    token_markers[is_bos_token] = 1
-    token_markers[is_eos_token] = -1
-
-    cumulative_markers = torch.cumsum(token_markers, dim=-1)
-    min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
-    loss_mask = cumulative_markers - min_cumulative
-
-    return loss_mask.to(dtype=torch.bool)
-
 class BaseSegmentFetcher:
 
     def __init__(self, segments: List[Tensor]):
@@ -111,11 +81,12 @@ class BaseDataset(Dataset, ABC):
-    def __init__(self, chunk_size: int):
+    def __init__(self, chunk_size: int, step_size: int):
         super().__init__()
         self.segments: MutiSeg = {}
         self.chunk_size = chunk_size
-        self.total_samples = 0
+        self.step_size = step_size
+        self.total_samples = None
 
     def save(self, save_path: str):
         keys = list(self.segments.keys())
@@ -140,16 +111,15 @@
         raise NotImplementedError
 
     def __len__(self) -> int:
-        assert self.total_samples // self.chunk_size > 0
-        return self.total_samples // self.chunk_size
+        assert self.total_samples is not None
+        if self.total_samples < self.chunk_size:
+            return 0
+        return (self.total_samples - self.chunk_size) // self.step_size + 1
 
 
 class SeqDataset(BaseDataset):
-    def __init__(
-        self,
-        chunk_size,
-    ):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
@@ -167,41 +137,27 @@ class SftDataset(BaseDataset):
-    def __init__(
-        self,
-        chunk_size,
-        bos_token_id,
-        eos_token_id,
-        user_token_id,
-        multi_turn=False,
-    ):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.user_token_id = user_token_id
-        self.multi_turn = multi_turn
-
-    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
-        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
+
+    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
+        return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index):
         begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
         end_idx = begin_idx + self.chunk_size
-        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
-        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
+        x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
+        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
+        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
 
-        # fix the eos_token_id bug(change target_ids to input_ids)
-        loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
-        attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
-
-        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
+        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
 
 
 class DpoDataset(BaseDataset):
-    def __init__(self, chunk_size: int):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@@ -220,8 +176,8 @@ class PpoDataset(BaseDataset):
-    def __init__(self, chunk_size: int):
-        super().__init__(chunk_size)
+    def __init__(self, chunk_size: int, step_size: int):
+        super().__init__(chunk_size, step_size)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@@ -245,19 +201,16 @@ class DatasetLoader:
         train_type: Literal["seq", "sft", "dpo"],
         load_path: Union[str, List[str]],
         max_len: int,
+        step_size: Optional[int] = None,
         **kwargs
     ) -> BaseDataset:
+        if step_size is None:
+            step_size = max_len
         dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
-            "seq": lambda max_len: SeqDataset(max_len),
-            "sft": lambda max_len: SftDataset(
-                max_len,
-                bos_token_id=kwargs.get("bos_token_id"),
-                eos_token_id=kwargs.get("eos_token_id"),
-                user_token_id=kwargs.get("user_token_id"),
-                multi_turn=kwargs.get("multi_turn")
-            ),
-            "dpo": lambda max_len: DpoDataset(max_len),
+            "seq": lambda max_len: SeqDataset(max_len, step_size),
+            "sft": lambda max_len: SftDataset(max_len, step_size),
+            "dpo": lambda max_len: DpoDataset(max_len, step_size),
         }
         dataset = dataset_router[train_type](max_len)
         dataset.load(load_path)
diff --git a/khaosz/data/tokenizer.py b/khaosz/data/tokenizer.py
index 65e06b4..935027d 100644
--- a/khaosz/data/tokenizer.py
+++ b/khaosz/data/tokenizer.py
@@ -8,7 +8,7 @@ from typing import List, Union
 class BpeTokenizer:
     def __init__(self, path=None):
         self._control_tokens = ["", "", ""]
-        self._special_tokens = ["<|user|>", "<|system|>"]
+        self._special_tokens = ["<|im_start|>", "<|im_end|>"]
         model = BPE()
         tokenizer = Tokenizer(model)
         tokenizer.normalizer = normalizers.Sequence([
@@ -93,9 +93,7 @@ class BpeTokenizer:
 
     @property
     def stop_ids(self) -> List[int]:
-        stop_ids = []
-        for token in self._control_tokens:
-            stop_ids.append(self._tokenizer.token_to_id(token))
+        stop_ids = [self._tokenizer.token_to_id(token) for token in self._control_tokens + self._special_tokens]
         return stop_ids
 
     @property
@@ -108,12 +106,4 @@ class BpeTokenizer:
 
     @property
     def pad_id(self) -> int:
-        return self._tokenizer.token_to_id("")
-
-    @property
-    def user_id(self) -> int:
-        return self._tokenizer.token_to_id("<|user|>")
-
-    @property
-    def system_id(self) -> int:
-        return self._tokenizer.token_to_id("<|system|>")
\ No newline at end of file
+        return self._tokenizer.token_to_id("")
\ No newline at end of file
diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py
index 297c9d5..cbd404f 100644
--- a/khaosz/inference/generator.py
+++ b/khaosz/inference/generator.py
@@ -5,30 +5,35 @@ from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
 from khaosz.config.param_config import ModelParameter
 
 
-def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
+def build_prompt(
+    query: str,
+    init_prompt: Optional[str] = None,
+    history: Optional[List[Tuple[str, str]]] = None
+) -> str:
     """
-    Build prompt for query and history
+    Build prompt in ChatML format for query and history
 
     Args:
         query(str): query string
+        init_prompt(Optional[str]): optional system prompt placed at the start
         history(Optional[List[Tuple[str, str]]]): history list of query and response
 
     Returns:
-        str: prompt string
+        str: prompt string in ChatML format
     """
-    prompt_parts = []
+    prompt = f"<|im_start|>system\n{init_prompt}<|im_end|>\n" if init_prompt else ""
 
-    if history is None:
-        history = []
+    # convert (user, assistant) tuples to ChatML turns
+    if history:
+        for user_msg, assistant_msg in history:
+            prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+            prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
 
-    for his_query, his_response in history:
-        prompt_parts.append(f"<|user|> {his_query} <|system|> {his_response}")
-
-    if query is not None:
-        prompt_parts.append(f"<|user|> {query} <|system|> ")
+    prompt += f"<|im_start|>user\n{query}<|im_end|>\n"
+    prompt += "<|im_start|>assistant\n"
 
-    return "\n".join(prompt_parts)
+    return prompt
 
 
 def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
     """
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 0771a55..9d79d84 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -68,11 +68,10 @@ class SftStrategy(BaseStrategy):
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
         batch = move_to_device(batch, self.device)
-        input_ids, target_ids = batch["input_ids"], batch["target_ids"]
-        loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
+        input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]
         ignore_index = -100
 
-        logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"]
+        logits = self.model(input_ids=input_ids)["logits"]
         target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
 
         loss = F.cross_entropy(
diff --git a/tests/conftest.py b/tests/conftest.py
index 17412c9..4eaa9e3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,7 +10,6 @@ import matplotlib
 from torch.utils.data import Dataset
 
 from khaosz.config.model_config import TransformerConfig
-from khaosz.data.data_util import build_attention_mask, build_loss_mask
 from khaosz.data.tokenizer import BpeTokenizer
 from khaosz.model.transformer import Transformer
 
@@ -46,14 +45,12 @@ class MultiTurnDataset(Dataset):
     def __getitem__(self, idx):
         input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
         target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
-        loss_mask = build_loss_mask(input_ids, 0, 1)
-        attn_mask = build_attention_mask(input_ids, 2, True)
+        loss_mask = torch.randint(0, 2, (self.max_length,))
 
         return {
             "input_ids": input_ids,
             "target_ids": target_ids,
             "loss_mask": loss_mask,
-            "attn_mask": attn_mask,
         }
 
diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py
index da8a8c0..61ce506 100644
--- a/tests/test_train_strategy.py
+++ b/tests/test_train_strategy.py
@@ -31,7 +31,6 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset):
         base_test_env["device"],
         bos_token_id=2,
         eos_token_id=3,
-        user_token_id=1,
         multi_turn=True
     )
 
diff --git a/train.py b/train.py
index c230f6d..cd6e2ff 100644
--- a/train.py
+++ b/train.py
@@ -61,7 +61,6 @@ def train(
         "bos_token_id": parameter.tokenizer.bos_id,
         "eos_token_id": parameter.tokenizer.eos_id,
         "pad_token_id": parameter.tokenizer.pad_id,
-        "user_token_id":parameter.tokenizer.user_id,
     }
 
     strategy = StrategyFactory.load(
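
Note on the new step_size parameter: BaseDataset.__len__ now counts sliding windows of chunk_size tokens advanced by step_size, instead of disjoint chunk_size chunks. A minimal standalone sketch of the arithmetic (illustrative names, not part of the patch):

    def num_windows(total_samples: int, chunk_size: int, step_size: int) -> int:
        # mirrors the new BaseDataset.__len__: zero windows when the data is
        # shorter than one chunk, otherwise one window per step that still fits
        if total_samples < chunk_size:
            return 0
        return (total_samples - chunk_size) // step_size + 1

    assert num_windows(10, 4, 4) == 2  # step_size == chunk_size: old disjoint chunking
    assert num_windows(10, 4, 2) == 4  # overlapping windows at offsets 0, 2, 4, 6
    assert num_windows(3, 4, 2) == 0   # input shorter than a single chunk

DatasetLoader.load defaults step_size to max_len when it is omitted, so existing callers keep the previous non-overlapping behavior.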
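
For reference, the ChatML layout that the new build_prompt emits: a call like build_prompt("Hi", init_prompt="You are helpful", history=[("A", "B")]) produces

    <|im_start|>system
    You are helpful<|im_end|>
    <|im_start|>user
    A<|im_end|>
    <|im_start|>assistant
    B<|im_end|>
    <|im_start|>user
    Hi<|im_end|>
    <|im_start|>assistant

leaving the assistant turn open for generation; since <|im_end|> is now in _special_tokens, the stop_ids property lets decoding halt at the end of the assistant turn.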