40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
import os
|
|
import h5py
|
|
import numpy as np
|
|
import torch
|
|
|
|
from pathlib import Path
|
|
from torch import Tensor
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write groups of tensors to ``<file_path>/<file_name>.h5``.

    Each key of ``tensor_group`` becomes an HDF5 group. The group stores
    the number of tensors in its ``num_tensors`` attribute and one dataset
    per tensor, named ``data_<idx>`` in list order.

    Args:
        file_path: Directory to write into; created if it does not exist.
        file_name: File name without extension (``.h5`` is appended).
        tensor_group: Mapping from group name to the tensors to store.
    """
    os.makedirs(file_path, exist_ok=True)
    full_file_path = os.path.join(file_path, f"{file_name}.h5")

    # Mode 'w' truncates an existing file of the same name.
    with h5py.File(full_file_path, 'w') as f:
        for key, tensors in tensor_group.items():
            grp = f.create_group(key)
            grp.attrs['num_tensors'] = len(tensors)

            for idx, tensor in enumerate(tensors):
                # detach() first: Tensor.numpy() raises on tensors that
                # require grad; for other tensors detach() is a cheap no-op.
                arr = tensor.detach().cpu().numpy()
                grp.create_dataset(f'data_{idx}', data=arr)
|
|
|
|
def _dataset_sort_key(name: str):
    """Sort key that orders ``<prefix>_<int>`` names by their integer suffix."""
    prefix, sep, suffix = name.rpartition("_")
    if sep and suffix.isdecimal():
        return (0, prefix, int(suffix), name)
    # Names without a numeric suffix sort after the numbered ones, by name.
    return (1, name, 0, name)


def load_h5(file_path: str) -> Dict[str, List[Tensor]]:
    """Load every ``*.h5`` / ``*.hdf5`` file found recursively under a directory.

    Each top-level HDF5 group is read into a list of tensors, one per
    dataset in the group. Groups with the same name in different files are
    concatenated rather than overwritten. Every tensor is moved to shared
    memory via ``share_memory_()``.

    Args:
        file_path: Root directory that is searched recursively.

    Returns:
        Mapping from group name to the list of loaded tensors. Empty if no
        matching files are found.
    """
    tensor_group: Dict[str, List[Tensor]] = {}

    root_path = Path(file_path)
    # rglob yields files in arbitrary order; sort so that the result does
    # not depend on directory-listing order.
    h5_files = sorted(root_path.rglob("*.h5")) + sorted(root_path.rglob("*.hdf5"))

    for h5_file in h5_files:
        with h5py.File(h5_file, 'r') as f:
            for key in f.keys():
                grp = f[key]
                # h5py iterates members in alphanumeric name order by default
                # ("data_10" before "data_2"); sort numerically to restore the
                # order the datasets were written in.
                dset_names = sorted(grp.keys(), key=_dataset_sort_key)
                dsets = [
                    torch.from_numpy(grp[dset_name][:]).share_memory_()
                    for dset_name in dset_names
                ]
                # Accumulate so a key present in several files keeps all data.
                tensor_group.setdefault(key, []).extend(dsets)

    return tensor_group