refactor(data): 将内存映射文件加载逻辑移至独立的 MmapFileHander 类
This commit is contained in:
parent
d882f65579
commit
701fb9bf78
|
|
@ -1,93 +1,16 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import torch
|
import torch
|
||||||
import bisect
|
import bisect
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Callable, List, Dict, Literal, Optional, Tuple, Union
|
from khaosz.data.mmap import MmapFileHander
|
||||||
|
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
Seg = List[Tensor]
|
Seg = List[Tensor]
|
||||||
MultiSeg = Dict[str, Seg]
|
MultiSeg = Dict[str, Seg]
|
||||||
|
|
||||||
|
|
||||||
def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
|
|
||||||
"""Load memory-mapped binary files as torch tensors.
|
|
||||||
|
|
||||||
Loads configuration from file_mapper.json in the specified directory, then loads
|
|
||||||
corresponding binary files as memory-mapped tensors. Returns tensors grouped by key
|
|
||||||
and total number of elements.
|
|
||||||
|
|
||||||
json metadata like this:
|
|
||||||
|
|
||||||
```
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"file_name": "file1.bin",
|
|
||||||
"size": 1000,
|
|
||||||
"dtype": "float32",
|
|
||||||
"key": "key1"
|
|
||||||
},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
root_path: Root directory path containing file_mapper.json and binary files
|
|
||||||
shared: Whether to load tensors in shared mode. If True, tensors can be
|
|
||||||
shared between processes
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If file_mapper.json or any binary file in config is missing
|
|
||||||
KeyError: If dtype in config is not in supported DTYPE_MAP
|
|
||||||
json.JSONDecodeError: If config file is not valid JSON
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple containing:
|
|
||||||
- MultiSeg: Dictionary of tensors grouped by key, structure: {key: [tensor1, tensor2, ...]}
|
|
||||||
- int: Total number of elements across all tensors
|
|
||||||
"""
|
|
||||||
|
|
||||||
DTYPE_MAP = {
|
|
||||||
"float32": torch.float32,
|
|
||||||
"float64": torch.float64,
|
|
||||||
"int32": torch.int32,
|
|
||||||
"int64": torch.int64,
|
|
||||||
"bool": torch.bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata_list = []
|
|
||||||
mmap_shared_group: MultiSeg = {}
|
|
||||||
|
|
||||||
file_mapper_path = os.path.join(root_path, "file_mapper.json")
|
|
||||||
if not os.path.exists(file_mapper_path):
|
|
||||||
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
|
|
||||||
|
|
||||||
with open(file_mapper_path, "r") as f:
|
|
||||||
metadata_list = json.load(f)
|
|
||||||
|
|
||||||
for metadata in metadata_list:
|
|
||||||
file_path = os.path.join(root_path, metadata["file_name"])
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise FileNotFoundError(f"Binary data file not found: {file_path}")
|
|
||||||
|
|
||||||
size = metadata["size"]
|
|
||||||
dtype = DTYPE_MAP[metadata["dtype"]]
|
|
||||||
segment_key = metadata["key"]
|
|
||||||
mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
|
|
||||||
|
|
||||||
if segment_key not in mmap_shared_group:
|
|
||||||
mmap_shared_group[segment_key] = []
|
|
||||||
|
|
||||||
mmap_shared_group[segment_key].append(mmap_tensor)
|
|
||||||
|
|
||||||
num_samples = sum(metadata["size"] for metadata in metadata_list
|
|
||||||
if segment_key == metadata["key"])
|
|
||||||
|
|
||||||
return mmap_shared_group, num_samples
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class BaseSegmentFetcher:
|
||||||
def __init__(self, segments: Seg):
|
def __init__(self, segments: Seg):
|
||||||
self.segments = segments
|
self.segments = segments
|
||||||
|
|
@ -151,7 +74,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.total_samples = None
|
self.total_samples = None
|
||||||
|
|
||||||
def load(self, load_path: str):
|
def load(self, load_path: str):
|
||||||
self.segments, self.total_samples = load_mmap_files(load_path)
|
self.segments, self.total_samples = MmapFileHander.load(load_path)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def get_index(self, index: int) -> int:
|
def get_index(self, index: int) -> int:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
|
class MmapFileHander:
|
||||||
|
"""
|
||||||
|
json metadata like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
[
|
||||||
|
{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"},
|
||||||
|
{"file_name": "file2.bin", "size": 2000, "dtype": "float32", "key": "key2"}
|
||||||
|
...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
files like:
|
||||||
|
|
||||||
|
```
|
||||||
|
file_mapper.json
|
||||||
|
file1.bin
|
||||||
|
file2.bin
|
||||||
|
...
|
||||||
|
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
DTYPE_MAP = {
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float64": torch.float64,
|
||||||
|
"int32": torch.int32,
|
||||||
|
"int64": torch.int64,
|
||||||
|
"bool": torch.bool,
|
||||||
|
}
|
||||||
|
REVERSE_DTYPE_MAP = {v: k for k, v in DTYPE_MAP.items()}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||||
|
metadata_list = []
|
||||||
|
mmap_shared_group: Dict[str, List[Tensor]] = {}
|
||||||
|
|
||||||
|
file_mapper_path = os.path.join(root_path, "file_mapper.json")
|
||||||
|
if not os.path.exists(file_mapper_path):
|
||||||
|
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
|
||||||
|
|
||||||
|
with open(file_mapper_path, "r") as f:
|
||||||
|
metadata_list = json.load(f)
|
||||||
|
|
||||||
|
for metadata in metadata_list:
|
||||||
|
file_path = os.path.join(root_path, metadata["file_name"])
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"Binary data file not found: {file_path}")
|
||||||
|
|
||||||
|
size = metadata["size"]
|
||||||
|
dtype = MmapFileHander.DTYPE_MAP[metadata["dtype"]]
|
||||||
|
segment_key = metadata["key"]
|
||||||
|
mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
|
||||||
|
|
||||||
|
if segment_key not in mmap_shared_group:
|
||||||
|
mmap_shared_group[segment_key] = []
|
||||||
|
|
||||||
|
mmap_shared_group[segment_key].append(mmap_tensor)
|
||||||
|
|
||||||
|
num_samples = sum(metadata["size"] for metadata in metadata_list)
|
||||||
|
num_keys = len(set(metadata['key'] for metadata in metadata_list))
|
||||||
|
|
||||||
|
sample_per_key = num_samples / num_keys
|
||||||
|
|
||||||
|
return mmap_shared_group, sample_per_key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
|
||||||
|
metadata_list = []
|
||||||
|
for segment_key, segment_tensors in mmap_shared_group.items():
|
||||||
|
for idx, tensor in enumerate(segment_tensors):
|
||||||
|
metadata_list.append({
|
||||||
|
"file_name": f"{segment_key}_{idx}.bin",
|
||||||
|
"size": tensor.numel(),
|
||||||
|
"dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
|
||||||
|
"key": segment_key
|
||||||
|
})
|
||||||
|
file_path = os.path.join(save_path, f"{segment_key}_{idx}.bin")
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(tensor.cpu().numpy().tobytes())
|
||||||
|
|
||||||
|
metadata_path = os.path.join(save_path, "file_mapper.json")
|
||||||
|
with open(metadata_path, "w") as f:
|
||||||
|
json.dump(metadata_list, f)
|
||||||
Loading…
Reference in New Issue