refactor(data): 将内存映射文件加载逻辑移至独立的 MmapFileHander 类

This commit is contained in:
ViperEkura 2025-12-15 09:12:42 +08:00
parent d882f65579
commit 701fb9bf78
2 changed files with 95 additions and 80 deletions

View File

@ -1,93 +1,16 @@
import os
import json
import torch
import bisect
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from typing import Callable, List, Dict, Literal, Optional, Tuple, Union
from khaosz.data.mmap import MmapFileHander
from typing import Callable, List, Dict, Literal, Optional, Union
Seg = List[Tensor]
MultiSeg = Dict[str, Seg]
def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
"""Load memory-mapped binary files as torch tensors.
Loads configuration from file_mapper.json in the specified directory, then loads
corresponding binary files as memory-mapped tensors. Returns tensors grouped by key
and total number of elements.
json metadata like this:
```
[
{
"file_name": "file1.bin",
"size": 1000,
"dtype": "float32",
"key": "key1"
},
...
]
```
Args:
root_path: Root directory path containing file_mapper.json and binary files
shared: Whether to load tensors in shared mode. If True, tensors can be
shared between processes
Raises:
FileNotFoundError: If file_mapper.json or any binary file in config is missing
KeyError: If dtype in config is not in supported DTYPE_MAP
json.JSONDecodeError: If config file is not valid JSON
Returns:
Tuple containing:
- MultiSeg: Dictionary of tensors grouped by key, structure: {key: [tensor1, tensor2, ...]}
- int: Total number of elements across all tensors
"""
DTYPE_MAP = {
"float32": torch.float32,
"float64": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
"bool": torch.bool,
}
metadata_list = []
mmap_shared_group: MultiSeg = {}
file_mapper_path = os.path.join(root_path, "file_mapper.json")
if not os.path.exists(file_mapper_path):
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
with open(file_mapper_path, "r") as f:
metadata_list = json.load(f)
for metadata in metadata_list:
file_path = os.path.join(root_path, metadata["file_name"])
if not os.path.exists(file_path):
raise FileNotFoundError(f"Binary data file not found: {file_path}")
size = metadata["size"]
dtype = DTYPE_MAP[metadata["dtype"]]
segment_key = metadata["key"]
mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
if segment_key not in mmap_shared_group:
mmap_shared_group[segment_key] = []
mmap_shared_group[segment_key].append(mmap_tensor)
num_samples = sum(metadata["size"] for metadata in metadata_list
if segment_key == metadata["key"])
return mmap_shared_group, num_samples
class BaseSegmentFetcher:
def __init__(self, segments: Seg):
self.segments = segments
@ -151,7 +74,7 @@ class BaseDataset(Dataset, ABC):
self.total_samples = None
def load(self, load_path: str):
self.segments, self.total_samples = load_mmap_files(load_path)
self.segments, self.total_samples = MmapFileHander.load(load_path)
self.fetcher = MultiSegmentFetcher(self.segments)
def get_index(self, index: int) -> int:

92
khaosz/data/mmap.py Normal file
View File

@ -0,0 +1,92 @@
import os
import json
import torch
from torch import Tensor
from typing import List, Dict, Tuple
class MmapFileHander:
"""
json metadata like this:
```
[
{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"},
{"file_name": "file2.bin", "size": 2000, "dtype": "float32", "key": "key2"}
...
]
```
files like:
```
file_mapper.json
file1.bin
file2.bin
...
```
"""
DTYPE_MAP = {
"float32": torch.float32,
"float64": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
"bool": torch.bool,
}
REVERSE_DTYPE_MAP = {v: k for k, v in DTYPE_MAP.items()}
@staticmethod
def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
metadata_list = []
mmap_shared_group: Dict[str, List[Tensor]] = {}
file_mapper_path = os.path.join(root_path, "file_mapper.json")
if not os.path.exists(file_mapper_path):
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
with open(file_mapper_path, "r") as f:
metadata_list = json.load(f)
for metadata in metadata_list:
file_path = os.path.join(root_path, metadata["file_name"])
if not os.path.exists(file_path):
raise FileNotFoundError(f"Binary data file not found: {file_path}")
size = metadata["size"]
dtype = MmapFileHander.DTYPE_MAP[metadata["dtype"]]
segment_key = metadata["key"]
mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
if segment_key not in mmap_shared_group:
mmap_shared_group[segment_key] = []
mmap_shared_group[segment_key].append(mmap_tensor)
num_samples = sum(metadata["size"] for metadata in metadata_list)
num_keys = len(set(metadata['key'] for metadata in metadata_list))
sample_per_key = num_samples / num_keys
return mmap_shared_group, sample_per_key
@staticmethod
def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
os.makedirs(save_path, exist_ok=True)
metadata_list = []
for segment_key, segment_tensors in mmap_shared_group.items():
for idx, tensor in enumerate(segment_tensors):
metadata_list.append({
"file_name": f"{segment_key}_{idx}.bin",
"size": tensor.numel(),
"dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
"key": segment_key
})
file_path = os.path.join(save_path, f"{segment_key}_{idx}.bin")
with open(file_path, "wb") as f:
f.write(tensor.cpu().numpy().tobytes())
metadata_path = os.path.join(save_path, "file_mapper.json")
with open(metadata_path, "w") as f:
json.dump(metadata_list, f)