fix: 修复 load_h5 丢失文件的问题

This commit is contained in:
ViperEkura 2026-03-02 17:37:28 +08:00
parent 8a8d6369bc
commit dff58468d6
2 changed files with 14 additions and 10 deletions

View File

@ -13,11 +13,13 @@ class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += len(seg)
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total if segments else 0
self.total_length = total
def __len__(self) -> int:
return self.total_length

View File

@ -1,11 +1,10 @@
import os
import h5py
import numpy as np
import torch
from pathlib import Path
from torch import Tensor
from typing import Dict, List, Tuple
from typing import Dict, List
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
@ -14,15 +13,12 @@ def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]
with h5py.File(full_file_path, 'w') as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
grp.attrs['num_tensors'] = len(tensors)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f'data_{idx}', data=arr)
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
total_samples = 0
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
@ -34,7 +30,13 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
dsets.append(torch.from_numpy(dset[:]).share_memory_())
tensor_group[key] = dsets
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group