From 8434c199238851266a1bd882ec151313d2f8ca6c Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 9 Oct 2025 19:53:52 +0800 Subject: [PATCH] =?UTF-8?q?fix(khaosz/trainer):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E8=8E=B7=E5=8F=96=E4=B8=AD=E7=9A=84=E7=B4=A2?= =?UTF-8?q?=E5=BC=95=E8=8C=83=E5=9B=B4=E9=94=99=E8=AF=AF=E5=92=8C=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E4=BC=A0=E9=80=92=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/data_util.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py index 5c229de..7b6570e 100644 --- a/khaosz/trainer/data_util.py +++ b/khaosz/trainer/data_util.py @@ -72,8 +72,9 @@ class BaseSegmentFetcher: if begin_idx >= end_idx: return torch.tensor([], dtype=torch.long) - seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1) - seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1) + # fix the range index bug + seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) + seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) result_segments = [] @@ -165,7 +166,6 @@ class SeqDataset(BaseDataset): return {"input_ids": x, "target_ids": y} - class SftDataset(BaseDataset): def __init__( self, @@ -182,15 +182,15 @@ class SftDataset(BaseDataset): self.user_token_id = user_token_id self.multi_turn = multi_turn - def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: - return self.fetcher.key_fetch(begin_idx, end_idx, key) + def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: + return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") def __getitem__(self, index): begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) end_idx = begin_idx + self.chunk_size - x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) - y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) + x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) + y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) # fix the eos_token_id bug(change target_ids to input_ids) loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)