fix(data): 修复数据加载模块中的拼写错误并优化内存映射加载逻辑

This commit is contained in:
ViperEkura 2025-11-28 20:21:53 +08:00
parent 019bfe4e05
commit 1f5cba889b
2 changed files with 77 additions and 43 deletions

View File

@ -4,7 +4,7 @@ from khaosz.data.dataset import (
DpoDataset, DpoDataset,
SftDataset, SftDataset,
PpoDataset, PpoDataset,
MutiSegmentFetcher, MultiSegmentFetcher,
DatasetLoader, DatasetLoader,
load_pkl_files, load_pkl_files,
) )
@ -18,7 +18,7 @@ __all__ = [
"DpoDataset", "DpoDataset",
"SftDataset", "SftDataset",
"PpoDataset", "PpoDataset",
"MutiSegmentFetcher", "MultiSegmentFetcher",
"DatasetLoader", "DatasetLoader",
"load_pkl_files", "load_pkl_files",
"BpeTokenizer", "BpeTokenizer",

View File

@ -1,34 +1,80 @@
import os
import json
import torch import torch
import bisect import bisect
import pickle as pkl
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import Callable, List, Dict, Literal, Optional, Union from typing import Callable, List, Dict, Literal, Optional, Tuple, Union
# A segment is the list of tensors stored under one key; a multi-segment
# maps each key to its segment.
Seg = List[Tensor]
MultiSeg = Dict[str, Seg]


# NOTE(review): the package ``__init__`` shown in the same change still
# imports/exports ``load_pkl_files`` from this module, which no longer defines
# it -- confirm that export is updated to ``load_mmap_files``.
def load_mmap_files(root_path: str, shared: bool = True) -> Tuple[MultiSeg, int]:
    """Load the binary files listed in ``file_mapper.json`` as mapped tensors.

    ``file_mapper.json`` is read from ``root_path`` and must decode to a list
    of entries, each providing ``"file_name"``, ``"size"`` (number of
    elements), ``"dtype"`` and ``"key"``. Every listed file is opened with
    ``torch.from_file`` and the resulting tensors are grouped by ``"key"`` in
    the order they appear in the mapper.

    Args:
        root_path: Directory containing ``file_mapper.json`` and the binary
            files it references.
        shared: Forwarded to ``torch.from_file`` as its ``shared`` argument.

    Returns:
        A tuple ``(segments, num_samples)`` where ``segments`` has the shape
        ``{key: [tensor1, tensor2, ...]}`` and ``num_samples`` is the sum of
        the ``"size"`` field over every mapper entry.

    Raises:
        FileNotFoundError: If ``file_mapper.json`` or a listed binary file is
            missing.
        KeyError: If an entry lacks a required field or names a dtype that is
            not supported.
        ValueError: If a binary file holds fewer bytes than its entry declares.
        json.JSONDecodeError: If the mapper file is not valid JSON.
    """
    dtype_map = {
        "float32": torch.float32,
        "float64": torch.float64,
        "int32": torch.int32,
        "int64": torch.int64,
        "bool": torch.bool,
    }

    file_mapper_path = os.path.join(root_path, "file_mapper.json")
    if not os.path.exists(file_mapper_path):
        raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")

    with open(file_mapper_path, "r", encoding="utf-8") as f:
        metadata_list = json.load(f)

    # NOTE(review): this counts every entry regardless of key. If the mapper
    # lists several keys covering the same positions, the total is larger than
    # the per-key length -- confirm this matches how callers index the data.
    num_samples = sum(metadata["size"] for metadata in metadata_list)

    mmap_shared_group: MultiSeg = {}
    for metadata in metadata_list:
        file_path = os.path.join(root_path, metadata["file_name"])
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Binary data file not found: {file_path}")

        size = metadata["size"]
        dtype_name = metadata["dtype"]
        segment_key = metadata["key"]

        if dtype_name not in dtype_map:
            raise KeyError(
                f"Unsupported dtype {dtype_name!r} for {file_path}; "
                f"expected one of {sorted(dtype_map)}"
            )
        dtype = dtype_map[dtype_name]

        # Refuse short files up front: a shared mapping may otherwise grow the
        # file to the requested size instead of reporting the mismatch.
        required_bytes = size * torch.empty(0, dtype=dtype).element_size()
        actual_bytes = os.path.getsize(file_path)
        if actual_bytes < required_bytes:
            raise ValueError(
                f"Binary data file too small: {file_path} has {actual_bytes} "
                f"bytes, but {required_bytes} are required"
            )

        mmap_tensor = torch.from_file(
            file_path, shared=shared, size=size, dtype=dtype
        )
        mmap_shared_group.setdefault(segment_key, []).append(mmap_tensor)

    return mmap_shared_group, num_samples
class BaseSegmentFetcher: class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]): def __init__(self, segments: Seg):
self.segments = segments self.segments = segments
self.cum_lengths = [] self.cum_lengths = []
total = 0 total = 0
@ -58,8 +104,8 @@ class BaseSegmentFetcher:
return torch.cat(result_segments, dim=0) return torch.cat(result_segments, dim=0)
class MutiSegmentFetcher: class MultiSegmentFetcher:
def __init__(self, muti_segments: MutiSeg): def __init__(self, muti_segments: MultiSeg):
self.muti_keys = list(muti_segments.keys()) self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = { self.muti_fetchers = {
key: BaseSegmentFetcher(segments) key: BaseSegmentFetcher(segments)
@ -82,29 +128,17 @@ class MutiSegmentFetcher:
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
def __init__(self, window_size: int, stride: int, share_memory: bool = False):
    """Initialise an empty dataset; data is attached later by ``load``.

    Args:
        window_size: Stored as ``self.window_size``.
        stride: Stored as ``self.stride``.
        share_memory: Stored as ``self.share_memory``. Previously this
            argument was accepted and then discarded.
    """
    super().__init__()
    self.segments: MultiSeg = {}
    self.window_size = window_size
    self.stride = stride
    # Keep the flag so it is not silently dropped.
    # NOTE(review): presumably meant to drive the ``shared`` argument of the
    # mmap loader -- confirm before wiring it into ``load``.
    self.share_memory = share_memory
    # Unknown until ``load`` has run.
    self.total_samples = None
def save(self, save_path: str):
keys = list(self.segments.keys())
if not keys:
return
first_item = self.segments[keys[0]]
segment_size = len(first_item)
for i in range(segment_size):
formated_segment = {key: self.segments[key][i] for key in keys}
pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
def load(self, load_path: Union[str, List[str]]):
    """Load one or more mmap dataset directories and rebuild the fetcher.

    ``load_mmap_files`` accepts a single root directory, so each path is
    loaded separately and the results are merged: tensors are appended per
    key in the order the paths are given, and the sample counts are summed.
    Passing the whole list in one call would fail when the loader joins it
    with a file name.

    Args:
        load_path: A root directory, or a list of root directories, each
            containing a ``file_mapper.json``.
    """
    paths = [load_path] if isinstance(load_path, str) else list(load_path)

    merged_segments = {}
    total_samples = 0
    for root_path in paths:
        segments, num_samples = load_mmap_files(root_path)
        for key, tensors in segments.items():
            merged_segments.setdefault(key, []).extend(tensors)
        total_samples += num_samples

    self.segments = merged_segments
    self.total_samples = total_samples
    self.fetcher = MultiSegmentFetcher(self.segments)
def get_index(self, index: int) -> int: def get_index(self, index: int) -> int:
begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1) begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
@ -126,7 +160,7 @@ class BaseDataset(Dataset, ABC):
class SeqDataset(BaseDataset): class SeqDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
    """Set up the base dataset state and a fetcher over ``self.segments``."""
    super().__init__(window_size=window_size, stride=stride)
    # Built from whatever the base initialiser left in ``self.segments``.
    segments = self.segments
    self.fetcher = MultiSegmentFetcher(segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
@ -144,7 +178,7 @@ class SeqDataset(BaseDataset):
class SftDataset(BaseDataset): class SftDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
    """Set up the base dataset state and a fetcher over ``self.segments``."""
    super().__init__(window_size=window_size, stride=stride)
    # Built from whatever the base initialiser left in ``self.segments``.
    segments = self.segments
    self.fetcher = MultiSegmentFetcher(segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -162,7 +196,7 @@ class SftDataset(BaseDataset):
class DpoDataset(BaseDataset): class DpoDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
    """Set up the base dataset state and a fetcher over ``self.segments``."""
    super().__init__(window_size=window_size, stride=stride)
    # Built from whatever the base initialiser left in ``self.segments``.
    segments = self.segments
    self.fetcher = MultiSegmentFetcher(segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -181,7 +215,7 @@ class DpoDataset(BaseDataset):
class PpoDataset(BaseDataset): class PpoDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
    """Set up the base dataset state and a fetcher over ``self.segments``."""
    super().__init__(window_size=window_size, stride=stride)
    # Built from whatever the base initialiser left in ``self.segments``.
    segments = self.segments
    self.fetcher = MultiSegmentFetcher(segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)