fix(khaosz/trainer): 修复数据获取中的索引范围错误和参数传递问题
This commit is contained in:
parent
68a15005cb
commit
8434c19923
|
|
@ -72,8 +72,9 @@ class BaseSegmentFetcher:
|
||||||
if begin_idx >= end_idx:
|
if begin_idx >= end_idx:
|
||||||
return torch.tensor([], dtype=torch.long)
|
return torch.tensor([], dtype=torch.long)
|
||||||
|
|
||||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1)
|
# fix the range index bug
|
||||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1)
|
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||||
|
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||||
|
|
||||||
result_segments = []
|
result_segments = []
|
||||||
|
|
||||||
|
|
@ -165,7 +166,6 @@ class SeqDataset(BaseDataset):
|
||||||
return {"input_ids": x, "target_ids": y}
|
return {"input_ids": x, "target_ids": y}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SftDataset(BaseDataset):
|
class SftDataset(BaseDataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -182,15 +182,15 @@ class SftDataset(BaseDataset):
|
||||||
self.user_token_id = user_token_id
|
self.user_token_id = user_token_id
|
||||||
self.multi_turn = multi_turn
|
self.multi_turn = multi_turn
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||||
end_idx = begin_idx + self.chunk_size
|
end_idx = begin_idx + self.chunk_size
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||||
|
|
||||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||||
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue