import bisect
import json
import os
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset

Seg = List[Tensor]
MultiSeg = Dict[str, Seg]


def load_mmap_files(root_path: str, shared: bool = True) -> Tuple[MultiSeg, int]:
    """Load memory-mapped binary files as torch tensors.

    Loads configuration from file_mapper.json in the specified directory, then
    loads the corresponding binary files as memory-mapped tensors. Returns the
    tensors grouped by key, together with the number of elements per key.

    The JSON metadata looks like this:

    ```
    [
        {
            "file_name": "file1.bin",
            "size": 1000,
            "dtype": "float32",
            "key": "key1"
        },
        ...
    ]
    ```

    Args:
        root_path: Root directory path containing file_mapper.json and the
            binary files
        shared: Whether to load tensors in shared mode. If True, tensors can
            be shared between processes

    Raises:
        FileNotFoundError: If file_mapper.json or any binary file in the
            config is missing
        KeyError: If a dtype in the config is not in the supported DTYPE_MAP
        json.JSONDecodeError: If the config file is not valid JSON
        ValueError: If the keys cover different numbers of elements

    Returns:
        Tuple containing:
            - MultiSeg: Dictionary of tensors grouped by key, structured as
              {key: [tensor1, tensor2, ...]}
            - int: Number of elements per key (every key must cover the same
              total)
    """

    DTYPE_MAP = {
        "float32": torch.float32,
        "float64": torch.float64,
        "int32": torch.int32,
        "int64": torch.int64,
        "bool": torch.bool,
    }

    mmap_shared_group: MultiSeg = {}

    file_mapper_path = os.path.join(root_path, "file_mapper.json")
    if not os.path.exists(file_mapper_path):
        raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")

    with open(file_mapper_path, "r") as f:
        metadata_list = json.load(f)

    for metadata in metadata_list:
        file_path = os.path.join(root_path, metadata["file_name"])
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Binary data file not found: {file_path}")

        size = metadata["size"]
        dtype = DTYPE_MAP[metadata["dtype"]]
        segment_key = metadata["key"]
        mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)

        mmap_shared_group.setdefault(segment_key, []).append(mmap_tensor)

    # Every key must describe the same number of elements; that shared total
    # is the sample count.
    key_totals: Dict[str, int] = {}
    for metadata in metadata_list:
        key_totals[metadata["key"]] = key_totals.get(metadata["key"], 0) + metadata["size"]
    if len(set(key_totals.values())) > 1:
        raise ValueError(f"Keys cover different element counts: {key_totals}")
    num_samples = next(iter(key_totals.values()), 0)

    return mmap_shared_group, num_samples
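

# A minimal sketch of producing compatible inputs (the helper name, file name,
# sizes, and key below are hypothetical, not part of any shipped tooling):
# each entry in file_mapper.json points at a raw binary file whose element
# count and dtype match the metadata.
def write_example_files(root_path: str) -> None:
    tokens = torch.arange(1000, dtype=torch.int64)
    # Write raw bytes that torch.from_file can map back with dtype=int64.
    tokens.numpy().tofile(os.path.join(root_path, "file1.bin"))
    mapper = [{"file_name": "file1.bin", "size": 1000, "dtype": "int64", "key": "sequence"}]
    with open(os.path.join(root_path, "file_mapper.json"), "w") as f:
        json.dump(mapper, f)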


class BaseSegmentFetcher:
    def __init__(self, segments: Seg):
        self.segments = segments
        # cum_lengths[i] is the total number of elements in segments[0..i].
        self.cum_lengths = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)
        self.total_length = total

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return torch.tensor([], dtype=self.segments[0].dtype if self.segments else torch.long)

        # bisect_right finds the first segment whose cumulative end exceeds
        # begin_idx; bisect_left finds the last segment the range touches.
        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
        seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)

        result_segments = []

        for i in range(seg_start_idx, seg_end_idx + 1):
            prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
            start = max(begin_idx - prev_cum, 0)
            end = min(end_idx - prev_cum, len(self.segments[i]))
            result_segments.append(self.segments[i][start:end])

        return torch.cat(result_segments, dim=0)
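

# Illustration of the windowing arithmetic above (the sizes are made up): for
# segments of lengths [4, 3, 5], cum_lengths is [4, 7, 12], so fetch_data(3, 9)
# concatenates segments[0][3:4], segments[1][0:3], and segments[2][0:2]:
#
#     fetcher = BaseSegmentFetcher([torch.arange(4), torch.arange(3), torch.arange(5)])
#     fetcher.fetch_data(3, 9).tolist()  # [3, 0, 1, 2, 0, 1]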


class MultiSegmentFetcher:
    def __init__(self, multi_segments: MultiSeg):
        self.multi_keys = list(multi_segments.keys())
        self.multi_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in multi_segments.items()
        }

    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Dict[str, Tensor]]:
        fetch_dict = {}
        keys = [keys] if isinstance(keys, str) else keys

        for key in keys:
            fetcher = self.multi_fetchers[key]
            fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
            fetch_dict[key] = fetch_tensor

        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Dict[str, Tensor]]:
        return self.key_fetch(begin_idx, end_idx, self.multi_keys)
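

# Note the return shape of key_fetch: a single key (or a one-element list)
# yields a bare Tensor, while several keys yield a dict, e.g. (with
# hypothetical keys) fetcher.key_fetch(0, 8, ["sequence", "loss_mask"])
# returns {"sequence": Tensor, "loss_mask": Tensor}.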


class BaseDataset(Dataset, ABC):
    def __init__(self, window_size: int, stride: int):
        super().__init__()
        self.segments: MultiSeg = {}
        self.window_size = window_size
        self.stride = stride
        self.total_samples = None
        # Replaced with a fetcher over the real data once load() runs.
        self.fetcher = MultiSegmentFetcher(self.segments)

    def load(self, load_path: str):
        self.segments, self.total_samples = load_mmap_files(load_path)
        self.fetcher = MultiSegmentFetcher(self.segments)

    def get_index(self, index: int) -> Tuple[int, int]:
        # Clamp so that the shifted target window [begin + 1, end + 1) used by
        # the subclasses still fits inside the data.
        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
        end_idx = begin_idx + self.window_size

        return begin_idx, end_idx

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        raise NotImplementedError

    def __len__(self) -> int:
        assert self.total_samples is not None
        if self.total_samples <= self.window_size:
            return 0
        # Valid begin indices are 0, stride, 2 * stride, ..., up to
        # total_samples - window_size - 1.
        return (self.total_samples - self.window_size - 1) // self.stride + 1
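

# Worked example of the length arithmetic (the numbers are illustrative): with
# total_samples == 10, window_size == 4, and stride == 2, the valid begin
# indices are 0, 2, and 4, so __len__ returns (10 - 4 - 1) // 2 + 1 == 3.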


class SeqDataset(BaseDataset):
    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index: int):
        begin_idx, end_idx = self.get_index(index)

        # Targets are the inputs shifted one position to the right.
        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)

        return {"input_ids": x, "target_ids": y}


class SftDataset(BaseDataset):
    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        begin_idx, end_idx = self.get_index(index)

        x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
        # The loss mask is shifted along with the targets it applies to.
        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)

        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}


class DpoDataset(BaseDataset):
    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        begin_idx, end_idx = self.get_index(index)

        chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
        rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
        chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
        rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)

        return {
            "chosen": chosen,
            "rejected": rejected,
            "chosen_mask": chosen_mask,
            "rejected_mask": rejected_mask,
        }


class PpoDataset(BaseDataset):
    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        begin_idx, end_idx = self.get_index(index)

        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids")
        actions = self._fetch_data(begin_idx, end_idx, "actions")
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs")
        rewards = self._fetch_data(begin_idx, end_idx, "rewards")

        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}


class DatasetLoader:
    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo", "ppo"],
        load_path: str,
        window_size: int,
        stride: Optional[int] = None,
    ) -> BaseDataset:
        if stride is None:
            stride = window_size

        dataset_router: Dict[str, Callable[[int, int], BaseDataset]] = {
            "seq": SeqDataset,
            "sft": SftDataset,
            "dpo": DpoDataset,
            "ppo": PpoDataset,
        }
        dataset = dataset_router[train_type](window_size, stride)
        dataset.load(load_path)

        return dataset
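

# Minimal usage sketch (the path and sizes are hypothetical):
#
#     dataset = DatasetLoader.load("seq", "/data/pretrain", window_size=1024)
#     sample = dataset[0]
#     sample["input_ids"].shape  # torch.Size([1024])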