fix: 修复 load_h5 丢失文件的问题 (fix: load_h5 lost tensors from earlier files when multiple .h5 files shared the same group key — now extends `tensor_group[key]` instead of overwriting it)
This commit is contained in:
parent
8a8d6369bc
commit
dff58468d6
|
|
@ -13,11 +13,13 @@ class BaseSegmentFetcher:
|
|||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
|
||||
total = 0
|
||||
for seg in segments:
|
||||
total += len(seg)
|
||||
total += torch.numel(seg)
|
||||
self.cum_lengths.append(total)
|
||||
self.total_length = total if segments else 0
|
||||
|
||||
self.total_length = total
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.total_length
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import os
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
|
|
@ -14,15 +13,12 @@ def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]
|
|||
with h5py.File(full_file_path, 'w') as f:
|
||||
for key, tensors in tensor_group.items():
|
||||
grp = f.create_group(key)
|
||||
grp.attrs['num_tensors'] = len(tensors)
|
||||
|
||||
for idx, tensor in enumerate(tensors):
|
||||
arr = tensor.cpu().numpy()
|
||||
grp.create_dataset(f'data_{idx}', data=arr)
|
||||
|
||||
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
total_samples = 0
|
||||
|
||||
root_path = Path(file_path)
|
||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||
|
|
@ -34,7 +30,13 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
|||
dsets = []
|
||||
for dset_name in grp.keys():
|
||||
dset = grp[dset_name]
|
||||
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
||||
tensor_group[key] = dsets
|
||||
tensor = torch.from_numpy(dset[:])
|
||||
if share_memory:
|
||||
tensor = tensor.share_memory_()
|
||||
dsets.append(tensor)
|
||||
|
||||
if tensor_group.get(key) is None:
|
||||
tensor_group[key] = []
|
||||
tensor_group[key].extend(dsets)
|
||||
|
||||
return tensor_group
|
||||
Loading…
Reference in New Issue