fix: 修复 load_h5 丢失文件的问题

This commit is contained in:
ViperEkura 2026-03-02 17:37:28 +08:00
parent 8a8d6369bc
commit dff58468d6
2 changed files with 14 additions and 10 deletions

View File

@@ -13,11 +13,13 @@ class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]): def __init__(self, segments: List[Tensor]):
self.segments = segments self.segments = segments
self.cum_lengths = [] self.cum_lengths = []
total = 0 total = 0
for seg in segments: for seg in segments:
total += len(seg) total += torch.numel(seg)
self.cum_lengths.append(total) self.cum_lengths.append(total)
self.total_length = total if segments else 0
self.total_length = total
def __len__(self) -> int: def __len__(self) -> int:
return self.total_length return self.total_length

View File

@@ -1,11 +1,10 @@
import os import os
import h5py import h5py
import numpy as np
import torch import torch
from pathlib import Path from pathlib import Path
from torch import Tensor from torch import Tensor
from typing import Dict, List, Tuple from typing import Dict, List
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
@@ -14,15 +13,12 @@ def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]
with h5py.File(full_file_path, 'w') as f: with h5py.File(full_file_path, 'w') as f:
for key, tensors in tensor_group.items(): for key, tensors in tensor_group.items():
grp = f.create_group(key) grp = f.create_group(key)
grp.attrs['num_tensors'] = len(tensors)
for idx, tensor in enumerate(tensors): for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy() arr = tensor.cpu().numpy()
grp.create_dataset(f'data_{idx}', data=arr) grp.create_dataset(f'data_{idx}', data=arr)
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]: def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {} tensor_group: Dict[str, List[Tensor]] = {}
total_samples = 0
root_path = Path(file_path) root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
@@ -34,7 +30,13 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
dsets = [] dsets = []
for dset_name in grp.keys(): for dset_name in grp.keys():
dset = grp[dset_name] dset = grp[dset_name]
dsets.append(torch.from_numpy(dset[:]).share_memory_()) tensor = torch.from_numpy(dset[:])
tensor_group[key] = dsets if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group return tensor_group