fix(khaosz/trainer): 修复数据获取中的索引范围错误和参数传递问题

This commit is contained in:
ViperEkura 2025-10-09 19:53:52 +08:00
parent 68a15005cb
commit 8434c19923
1 changed file with 7 additions and 7 deletions

View File

@@ -72,8 +72,9 @@ class BaseSegmentFetcher:
if begin_idx >= end_idx: if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long) return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1) # fix the range index bug
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1) seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = [] result_segments = []
@@ -165,7 +166,6 @@ class SeqDataset(BaseDataset):
return {"input_ids": x, "target_ids": y} return {"input_ids": x, "target_ids": y}
class SftDataset(BaseDataset): class SftDataset(BaseDataset):
def __init__( def __init__(
self, self,
@@ -182,15 +182,15 @@ class SftDataset(BaseDataset):
self.user_token_id = user_token_id self.user_token_id = user_token_id
self.multi_turn = multi_turn self.multi_turn = multi_turn
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index): def __getitem__(self, index):
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size end_idx = begin_idx + self.chunk_size
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
# fix the eos_token_id bug(change target_ids to input_ids) # fix the eos_token_id bug(change target_ids to input_ids)
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id) loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)