From dff58468d6c47cb18d9d94e44c1a5a80822f1455 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 2 Mar 2026 17:37:28 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20load=5Fh5=20?= =?UTF-8?q?=E4=B8=A2=E5=A4=B1=E6=96=87=E4=BB=B6=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/data/dataset.py | 6 ++++-- khaosz/data/file.py | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py index 1255b8c..c15bc54 100644 --- a/khaosz/data/dataset.py +++ b/khaosz/data/dataset.py @@ -13,11 +13,13 @@ class BaseSegmentFetcher: def __init__(self, segments: List[Tensor]): self.segments = segments self.cum_lengths = [] + total = 0 for seg in segments: - total += len(seg) + total += torch.numel(seg) self.cum_lengths.append(total) - self.total_length = total if segments else 0 + + self.total_length = total def __len__(self) -> int: return self.total_length diff --git a/khaosz/data/file.py b/khaosz/data/file.py index 23bad89..e6d5535 100644 --- a/khaosz/data/file.py +++ b/khaosz/data/file.py @@ -1,11 +1,10 @@ import os import h5py -import numpy as np import torch from pathlib import Path from torch import Tensor -from typing import Dict, List, Tuple +from typing import Dict, List def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): @@ -14,15 +13,12 @@ def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor] with h5py.File(full_file_path, 'w') as f: for key, tensors in tensor_group.items(): grp = f.create_group(key) - grp.attrs['num_tensors'] = len(tensors) - for idx, tensor in enumerate(tensors): arr = tensor.cpu().numpy() grp.create_dataset(f'data_{idx}', data=arr) -def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]: +def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: tensor_group: Dict[str, List[Tensor]] = {} - total_samples = 0 root_path = Path(file_path) h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) @@ -34,7 +30,13 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]: dsets = [] for dset_name in grp.keys(): dset = grp[dset_name] - dsets.append(torch.from_numpy(dset[:]).share_memory_()) - tensor_group[key] = dsets + tensor = torch.from_numpy(dset[:]) + if share_memory: + tensor = tensor.share_memory_() + dsets.append(tensor) + + if tensor_group.get(key) is None: + tensor_group[key] = [] + tensor_group[key].extend(dsets) return tensor_group \ No newline at end of file