fix(data): fix typos in the data loading module and optimize the memory-mapped loading logic
parent 019bfe4e05
commit 1f5cba889b
@@ -4,7 +4,7 @@ from khaosz.data.dataset import (
     DpoDataset,
     SftDataset,
     PpoDataset,
-    MutiSegmentFetcher,
+    MultiSegmentFetcher,
     DatasetLoader,
     load_pkl_files,
 )
@@ -18,7 +18,7 @@ __all__ = [
     "DpoDataset",
     "SftDataset",
     "PpoDataset",
-    "MutiSegmentFetcher",
+    "MultiSegmentFetcher",
     "DatasetLoader",
     "load_pkl_files",
     "BpeTokenizer",
@@ -1,34 +1,80 @@
+import os
+import json
 import torch
 import bisect
-import pickle as pkl
 
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from typing import Callable, List, Dict, Literal, Optional, Union
+from typing import Callable, List, Dict, Literal, Optional, Tuple, Union
 
-MutiSeg = Dict[str, List[Tensor]]
-Seg = Dict[str, Tensor]
-
-def load_pkl_files(paths: List[str]):
-    segments: MutiSeg = {}
-    total_samples = 0
-
-    for path in paths:
-        with open(path, "rb") as f:
-            pkl_file: Seg = pkl.load(f)
-            for key, value in pkl_file.items():
-                if key not in segments:
-                    segments[key] = []
-                segments[key].append(value)
-            first_key = list(pkl_file.keys())[0]
-            total_samples += pkl_file[first_key].numel()
-
-    return segments, total_samples
+Seg = List[Tensor]
+MultiSeg = Dict[str, Seg]
+
+
+def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
+    """Load memory-mapped binary files as torch tensors.
+
+    Loads configuration from file_mapper.json in the specified directory, then loads
+    corresponding binary files as memory-mapped tensors. Returns tensors grouped by key
+    and total number of elements.
+
+    Args:
+        root_path: Root directory path containing file_mapper.json and binary files
+        shared: Whether to load tensors in shared mode. If True, tensors can be
+            shared between processes
+
+    Raises:
+        FileNotFoundError: If file_mapper.json or any binary file in config is missing
+        KeyError: If dtype in config is not in supported DTYPE_MAP
+        json.JSONDecodeError: If config file is not valid JSON
+
+    Returns:
+        Tuple containing:
+        - MultiSeg: Dictionary of tensors grouped by key, structure: {key: [tensor1, tensor2, ...]}
+        - int: Total number of elements across all tensors
+    """
+
+    DTYPE_MAP = {
+        "float32": torch.float32,
+        "float64": torch.float64,
+        "int32": torch.int32,
+        "int64": torch.int64,
+        "bool": torch.bool,
+    }
+
+    metadata_list = []
+    mmap_shared_group: MultiSeg = {}
+
+    file_mapper_path = os.path.join(root_path, "file_mapper.json")
+    if not os.path.exists(file_mapper_path):
+        raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
+
+    with open(file_mapper_path, "r") as f:
+        metadata_list = json.load(f)
+
+    num_samples = sum(metadata["size"] for metadata in metadata_list)
+
+    for metadata in metadata_list:
+        file_path = os.path.join(root_path, metadata["file_name"])
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"Binary data file not found: {file_path}")
+
+        size = metadata["size"]
+        dtype = DTYPE_MAP[metadata["dtype"]]
+        segment_key = metadata["key"]
+        mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
+
+        if segment_key not in mmap_shared_group:
+            mmap_shared_group[segment_key] = []
+
+        mmap_shared_group[segment_key].append(mmap_tensor)
+
+    return mmap_shared_group, num_samples
 
 
 class BaseSegmentFetcher:
-    def __init__(self, segments: List[Tensor]):
+    def __init__(self, segments: Seg):
         self.segments = segments
         self.cum_lengths = []
         total = 0
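For reference, a minimal sketch of the on-disk layout load_mmap_files expects, inferred from the metadata keys the function reads ("file_name", "size", "dtype", "key"); the directory name and tensor contents below are illustrative, not part of this commit:

import json
import os
import torch

root = "data_dir"  # hypothetical directory
os.makedirs(root, exist_ok=True)

# Write one raw binary shard; numpy's tofile() emits the bytes
# that torch.from_file reads back given a matching dtype and size.
seq = torch.arange(8, dtype=torch.int64)
seq.numpy().tofile(os.path.join(root, "sequence_0.bin"))

# Matching file_mapper.json entry (one dict per shard).
mapper = [{"file_name": "sequence_0.bin", "size": seq.numel(),
           "dtype": "int64", "key": "sequence"}]
with open(os.path.join(root, "file_mapper.json"), "w") as f:
    json.dump(mapper, f)

segments, total = load_mmap_files(root, shared=True)
assert total == 8 and segments["sequence"][0].dtype == torch.int64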
@@ -58,8 +104,8 @@ class BaseSegmentFetcher:
         return torch.cat(result_segments, dim=0)
 
 
-class MutiSegmentFetcher:
-    def __init__(self, muti_segments: MutiSeg):
+class MultiSegmentFetcher:
+    def __init__(self, muti_segments: MultiSeg):
         self.muti_keys = list(muti_segments.keys())
         self.muti_fetchers = {
             key: BaseSegmentFetcher(segments)
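The fetch body of BaseSegmentFetcher is truncated in this diff, but the cum_lengths bookkeeping, the bisect import, and the torch.cat(result_segments, dim=0) return suggest prefix-sum indexing across segment boundaries. A hedged sketch of that pattern, not the repository's actual method:

import bisect
import torch
from torch import Tensor
from typing import List

def fetch(segments: List[Tensor], cum_lengths: List[int],
          begin_idx: int, end_idx: int) -> Tensor:
    # cum_lengths[i] is assumed to hold the total length of segments[0..i].
    result_segments = []
    seg = bisect.bisect_right(cum_lengths, begin_idx)  # first segment containing begin_idx
    pos = begin_idx
    while pos < end_idx:
        seg_start = 0 if seg == 0 else cum_lengths[seg - 1]
        local_begin = pos - seg_start
        local_end = min(end_idx, cum_lengths[seg]) - seg_start
        result_segments.append(segments[seg][local_begin:local_end])
        pos = seg_start + local_end  # advance to the next segment boundary
        seg += 1
    return torch.cat(result_segments, dim=0)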
@@ -82,29 +128,17 @@ class MutiSegmentFetcher:
 
 
 class BaseDataset(Dataset, ABC):
-    def __init__(self, window_size: int, stride: int):
+    def __init__(self, window_size: int, stride: int, share_memory: bool=False):
         super().__init__()
-        self.segments: MutiSeg = {}
+        self.segments: MultiSeg = {}
         self.window_size = window_size
         self.stride = stride
         self.total_samples = None
 
-    def save(self, save_path: str):
-        keys = list(self.segments.keys())
-        if not keys:
-            return
-
-        first_item = self.segments[keys[0]]
-        segment_size = len(first_item)
-
-        for i in range(segment_size):
-            formated_segment = {key: self.segments[key][i] for key in keys}
-            pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
-
     def load(self, load_path: Union[str, List[str]]):
         paths = [load_path] if isinstance(load_path, str) else load_path
-        self.segments, self.total_samples = load_pkl_files(paths)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.segments, self.total_samples = load_mmap_files(paths)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
         begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
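The clamped window arithmetic in get_index, unchanged by this commit, is easy to sanity-check with concrete numbers:

# Worked example with illustrative values, mirroring
# min(index * stride, total_samples - window_size - 1).
total_samples, window_size, stride = 100, 10, 4
begin = lambda i: min(i * stride, total_samples - window_size - 1)
assert begin(5) == 20   # ordinary strided start: 5 * 4
assert begin(30) == 89  # clamped so the window stays in bounds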
@@ -126,7 +160,7 @@ class BaseDataset(Dataset, ABC):
 class SeqDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
@@ -144,7 +178,7 @@ class SeqDataset(BaseDataset):
 class SftDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
@@ -162,7 +196,7 @@ class SftDataset(BaseDataset):
 class DpoDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
@@ -181,7 +215,7 @@ class DpoDataset(BaseDataset):
 class PpoDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
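Putting the pieces together, a hedged usage sketch of the new loading path; the directory is hypothetical and assumed to follow the file_mapper.json layout shown earlier:

# Load memory-mapped shards and serve a window through the renamed fetcher.
segments, total = load_mmap_files("prepared/shards", shared=True)
fetcher = MultiSegmentFetcher(segments)
window = fetcher.key_fetch(0, 1024, "sequence")  # first 1024 elements under the "sequence" key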