fix(data): 修复数据加载模块中的拼写错误并优化内存映射加载逻辑

This commit is contained in:
ViperEkura 2025-11-28 20:21:53 +08:00
parent 019bfe4e05
commit 1f5cba889b
2 changed files with 77 additions and 43 deletions

View File

@ -4,7 +4,7 @@ from khaosz.data.dataset import (
DpoDataset,
SftDataset,
PpoDataset,
MutiSegmentFetcher,
MultiSegmentFetcher,
DatasetLoader,
load_pkl_files,
)
@ -18,7 +18,7 @@ __all__ = [
"DpoDataset",
"SftDataset",
"PpoDataset",
"MutiSegmentFetcher",
"MultiSegmentFetcher",
"DatasetLoader",
"load_pkl_files",
"BpeTokenizer",

View File

@ -1,34 +1,80 @@
import os
import json
import torch
import bisect
import pickle as pkl
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from typing import Callable, List, Dict, Literal, Optional, Union
from typing import Callable, List, Dict, Literal, Optional, Tuple, Union
MutiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor]
Seg = List[Tensor]
MultiSeg = Dict[str, Seg]
def load_pkl_files(paths: List[str]):
segments: MutiSeg = {}
total_samples = 0
for path in paths:
with open(path, "rb") as f:
pkl_file: Seg = pkl.load(f)
for key, value in pkl_file.items():
if key not in segments:
segments[key] = []
segments[key].append(value)
first_key = list(pkl_file.keys())[0]
total_samples += pkl_file[first_key].numel()
def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
"""Load memory-mapped binary files as torch tensors.
return segments, total_samples
Loads configuration from file_mapper.json in the specified directory, then loads
corresponding binary files as memory-mapped tensors. Returns tensors grouped by key
and total number of elements.
Args:
root_path: Root directory path containing file_mapper.json and binary files
shared: Whether to load tensors in shared mode. If True, tensors can be
shared between processes
Raises:
FileNotFoundError: If file_mapper.json or any binary file in config is missing
KeyError: If dtype in config is not in supported DTYPE_MAP
json.JSONDecodeError: If config file is not valid JSON
Returns:
Tuple containing:
- MultiSeg: Dictionary of tensors grouped by key, structure: {key: [tensor1, tensor2, ...]}
- int: Total number of elements across all tensors
"""
DTYPE_MAP = {
"float32": torch.float32,
"float64": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
"bool": torch.bool,
}
metadata_list = []
mmap_shared_group: MultiSeg = {}
file_mapper_path = os.path.join(root_path, "file_mapper.json")
if not os.path.exists(file_mapper_path):
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
with open(file_mapper_path, "r") as f:
metadata_list = json.load(f)
num_samples = sum(metadata["size"] for metadata in metadata_list)
for metadata in metadata_list:
file_path = os.path.join(root_path, metadata["file_name"])
if not os.path.exists(file_path):
raise FileNotFoundError(f"Binary data file not found: {file_path}")
size = metadata["size"]
dtype = DTYPE_MAP[metadata["dtype"]]
segment_key = metadata["key"]
mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
if segment_key not in mmap_shared_group:
mmap_shared_group[segment_key] = []
mmap_shared_group[segment_key].append(mmap_tensor)
return mmap_shared_group, num_samples
class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]):
def __init__(self, segments: Seg):
self.segments = segments
self.cum_lengths = []
total = 0
@ -58,8 +104,8 @@ class BaseSegmentFetcher:
return torch.cat(result_segments, dim=0)
class MutiSegmentFetcher:
def __init__(self, muti_segments: MutiSeg):
class MultiSegmentFetcher:
def __init__(self, muti_segments: MultiSeg):
self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = {
key: BaseSegmentFetcher(segments)
@ -82,29 +128,17 @@ class MutiSegmentFetcher:
class BaseDataset(Dataset, ABC):
def __init__(self, window_size: int, stride: int):
def __init__(self, window_size: int, stride: int, share_memory: bool=False):
super().__init__()
self.segments: MutiSeg = {}
self.segments: MultiSeg = {}
self.window_size = window_size
self.stride = stride
self.total_samples = None
def save(self, save_path: str):
keys = list(self.segments.keys())
if not keys:
return
first_item = self.segments[keys[0]]
segment_size = len(first_item)
for i in range(segment_size):
formated_segment = {key: self.segments[key][i] for key in keys}
pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
def load(self, load_path: Union[str, List[str]]):
paths = [load_path] if isinstance(load_path, str) else load_path
self.segments, self.total_samples = load_pkl_files(paths)
self.fetcher = MutiSegmentFetcher(self.segments)
self.segments, self.total_samples = load_mmap_files(paths)
self.fetcher = MultiSegmentFetcher(self.segments)
def get_index(self, index: int) -> int:
begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
@ -126,7 +160,7 @@ class BaseDataset(Dataset, ABC):
class SeqDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
@ -144,7 +178,7 @@ class SeqDataset(BaseDataset):
class SftDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -162,7 +196,7 @@ class SftDataset(BaseDataset):
class DpoDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -181,7 +215,7 @@ class DpoDataset(BaseDataset):
class PpoDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
self.fetcher = MultiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)