From 3ee84b31a049bc01ad1eaab7aacfd3b64196edb2 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Fri, 28 Nov 2025 20:59:24 +0800
Subject: [PATCH] feat(data): refactor dataset loading logic, fix counting
 error
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 khaosz/data/__init__.py      |   4 +-
 khaosz/data/dataset.py       |  12 +-
 tests/conftest.py            |   2 +-
 tests/test_dataset_loader.py | 209 ++++++++++++++++++++++++++++++-----
 4 files changed, 187 insertions(+), 40 deletions(-)

diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py
index 4d1930c..38164f7 100644
--- a/khaosz/data/__init__.py
+++ b/khaosz/data/__init__.py
@@ -5,8 +5,7 @@ from khaosz.data.dataset import (
     SftDataset,
     PpoDataset,
     MultiSegmentFetcher,
-    DatasetLoader,
-    load_pkl_files,
+    DatasetLoader
 )
 
 from khaosz.data.tokenizer import BpeTokenizer
@@ -20,7 +19,6 @@ __all__ = [
     "PpoDataset",
     "MultiSegmentFetcher",
     "DatasetLoader",
-    "load_pkl_files",
     "BpeTokenizer",
     "ResumableDistributedSampler"
 ]
\ No newline at end of file
diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index ba540a4..5219e16 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -67,8 +67,6 @@ def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
     with open(file_mapper_path, "r") as f:
         metadata_list = json.load(f)
 
-    num_samples = sum(metadata["size"] for metadata in metadata_list)
-
     for metadata in metadata_list:
         file_path = os.path.join(root_path, metadata["file_name"])
         if not os.path.exists(file_path):
@@ -84,6 +82,9 @@ def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
 
             mmap_shared_group[segment_key].append(mmap_tensor)
 
+        num_samples = sum(metadata["size"] for metadata in metadata_list
+                          if segment_key == metadata["key"])
+
     return mmap_shared_group, num_samples
 
 
@@ -142,16 +143,15 @@ class MultiSegmentFetcher:
 
 
 class BaseDataset(Dataset, ABC):
-    def __init__(self, window_size: int, stride: int, share_memory: bool=False):
+    def __init__(self, window_size: int, stride: int):
         super().__init__()
         self.segments: MultiSeg = {}
         self.window_size = window_size
         self.stride = stride
        self.total_samples = None
 
-    def load(self, load_path: Union[str, List[str]]):
-        paths = [load_path] if isinstance(load_path, str) else load_path
-        self.segments, self.total_samples = load_mmap_files(paths)
+    def load(self, load_path: str):
+        self.segments, self.total_samples = load_mmap_files(load_path)
        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
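The counting fix above is worth spelling out: file_mapper.json holds one metadata entry per stored segment, and segments for different keys (e.g. chosen and rejected) describe the same underlying samples, so the old sum over all entries multiplied the sample count by the number of keys. The new code restricts the sum to a single segment key. A standalone sketch of the arithmetic (the metadata values are illustrative, not taken from the repo):

    metadata_list = [
        {"file_name": "chosen.bin",   "size": 100, "dtype": "int64", "key": "chosen"},
        {"file_name": "rejected.bin", "size": 100, "dtype": "int64", "key": "rejected"},
    ]

    # Old behaviour: sums sizes across every segment key -> 200 "samples".
    old_count = sum(m["size"] for m in metadata_list)

    # New behaviour: restrict the sum to one segment key -> 100 samples.
    new_count = sum(m["size"] for m in metadata_list if m["key"] == "rejected")

    assert (old_count, new_count) == (200, 100)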
diff --git a/tests/conftest.py b/tests/conftest.py
index 9ab3177..093fb37 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -108,7 +108,7 @@ def base_test_env(request: pytest.FixtureRequest):
 
     yield {
         "device": device,
-        "test_dir": test_dir,
+        "test_dir": str(test_dir),
         "config_path": config_path,
         "transformer_config": transformer_config,
         "model": model,
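The rewritten tests below replace the old pickle fixtures with the memory-mapped layout that load_mmap_files consumes: a directory containing one raw .bin file per segment key plus a file_mapper.json index. For orientation, a standalone round trip in the same shape the create_mmap_dataset helper produces (paths and sizes are illustrative):

    import json, os
    import numpy as np

    root = "sft_data"                                  # illustrative directory
    os.makedirs(root, exist_ok=True)

    tokens = np.random.randint(0, 1000, size=128).astype(np.int64)
    tokens.tofile(os.path.join(root, "sequence.bin"))  # raw bytes, no header

    mapper = [{"file_name": "sequence.bin", "size": len(tokens),
               "dtype": str(tokens.dtype), "key": "sequence"}]
    with open(os.path.join(root, "file_mapper.json"), "w") as f:
        json.dump(mapper, f, indent=2)

    # Raw .bin files carry no dtype/shape header, which is why the JSON
    # index must record both; a plain numpy memmap can read them back:
    arr = np.memmap(os.path.join(root, "sequence.bin"), dtype=np.int64, mode="r")
    assert arr.shape[0] == len(tokens)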
diff --git a/tests/test_dataset_loader.py b/tests/test_dataset_loader.py
index a6e282c..3187fe4 100644
--- a/tests/test_dataset_loader.py
+++ b/tests/test_dataset_loader.py
@@ -1,67 +1,216 @@
 import os
+import json
 import torch
-import pickle
 import numpy as np
 from khaosz.trainer import *
 from khaosz.data.dataset import *
 
+
+def create_mmap_dataset(dir_path, data_dict, dataset_name):
+    """Helper function to create a memory-mapped dataset for testing"""
+    dataset_dir = os.path.join(dir_path, dataset_name)
+    os.makedirs(dataset_dir, exist_ok=True)
+
+    file_mapper = []
+
+    for key, tensor in data_dict.items():
+        # Convert the tensor to a numpy array
+        np_array = tensor.numpy()
+        file_name = f"{key}.bin"
+        file_path = os.path.join(dataset_dir, file_name)
+
+        # Save as a raw binary file
+        np_array.tofile(file_path)
+
+        # Record the segment in the file mapper
+        file_mapper.append({
+            "file_name": file_name,
+            "size": len(np_array),
+            "dtype": str(np_array.dtype),
+            "key": key
+        })
+
+    # Save the file mapper index
+    mapper_path = os.path.join(dataset_dir, "file_mapper.json")
+    with open(mapper_path, "w") as f:
+        json.dump(file_mapper, f, indent=2)
+
+    return dataset_dir
+
+
 def test_dataset_loader_random_paths(base_test_env):
     """Test dataset loader with multiple random paths"""
     test_dir = base_test_env["test_dir"]
 
-    # Create multiple pkl files with random data
+    # Create multiple mmap dataset directories with random data
     num_files = np.random.randint(2, 5)
-    pkl_paths = []
 
     for i in range(num_files):
-        pkl_path = os.path.join(test_dir, f"test_data_{i}.pkl")
-        seq_length = np.random.randint(50, 100)
+        seq_length = np.random.randint(100, 200)
 
         dummy_data = {
-            "sequence": torch.randint(0, 1000, (seq_length,)),
-            "chosen": torch.randint(0, 1000, (seq_length,)),
-            "rejected": torch.randint(0, 1000, (seq_length,)),
-            "chosen_mask": torch.ones(seq_length, dtype=torch.bool),
-            "rejected_mask": torch.ones(seq_length, dtype=torch.bool)
+            "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
         }
 
-        with open(pkl_path, "wb") as f:
-            pickle.dump(dummy_data, f)
-        pkl_paths.append(pkl_path)
+        dataset_path = create_mmap_dataset(test_dir, dummy_data, f"test_data_{i}")
 
-    # Test loading with multiple paths
-    loaded_dataset = DatasetLoader.load(
-        train_type="seq",
-        load_path=pkl_paths,
-        window_size=64,
-    )
-    assert loaded_dataset is not None
-    assert len(loaded_dataset) > 0
+        # Load each generated dataset path individually
+        loaded_dataset = DatasetLoader.load(
+            train_type="seq",
+            load_path=dataset_path,
+            window_size=64,
+        )
+        assert loaded_dataset is not None
+        assert len(loaded_dataset) > 0
+
+        # Test that we can get items without errors
+        for j in range(min(3, len(loaded_dataset))):
+            item = loaded_dataset[j]
+            assert "input_ids" in item
+            assert "target_ids" in item
+            assert item["input_ids"].shape == item["target_ids"].shape
+            assert item["input_ids"].shape[0] == 64
+
 
 def test_dpo_strategy_with_random_data(base_test_env):
     """Test DPO strategy with randomized preference data"""
     test_dir = base_test_env["test_dir"]
 
-    # Create DPO-style data
-    pkl_path = os.path.join(test_dir, "dpo_data.pkl")
-    seq_length = np.random.randint(40, 80)
+    # Create DPO-style data in memory-mapped format
+    seq_length = np.random.randint(100, 200)
 
     dummy_data = {
-        "chosen": torch.randint(0, 1000, (seq_length,)),
-        "rejected": torch.randint(0, 1000, (seq_length,)),
+        "chosen": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+        "rejected": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
         "chosen_mask": torch.ones(seq_length, dtype=torch.bool),
         "rejected_mask": torch.ones(seq_length, dtype=torch.bool)
     }
 
-    with open(pkl_path, "wb") as f:
-        pickle.dump(dummy_data, f)
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "dpo_data")
 
     # Load DPO dataset
     dpo_dataset = DatasetLoader.load(
         train_type="dpo",
-        load_path=pkl_path,
+        load_path=dataset_path,
         window_size=64,
     )
 
     assert dpo_dataset is not None
-    assert hasattr(dpo_dataset, 'fetcher')
\ No newline at end of file
+    assert hasattr(dpo_dataset, 'fetcher')
+    assert len(dpo_dataset) > 0
+
+    # Test that we can get DPO items without errors
+    for i in range(min(3, len(dpo_dataset))):
+        item = dpo_dataset[i]
+        assert "chosen" in item
+        assert "rejected" in item
+        assert "chosen_mask" in item
+        assert "rejected_mask" in item
+        assert item["chosen"].shape == item["rejected"].shape
+        assert item["chosen_mask"].shape == item["rejected_mask"].shape
+
+
+def test_sft_dataset_with_random_data(base_test_env):
+    """Test SFT dataset with random data"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create SFT-style data in memory-mapped format
+    seq_length = np.random.randint(100, 200)
+
+    dummy_data = {
+        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+        "loss_mask": torch.ones(seq_length, dtype=torch.bool)
+    }
+
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "sft_data")
+
+    # Load SFT dataset
+    sft_dataset = DatasetLoader.load(
+        train_type="sft",
+        load_path=dataset_path,
+        window_size=64,
+    )
+
+    assert sft_dataset is not None
+    assert hasattr(sft_dataset, 'fetcher')
+    assert len(sft_dataset) > 0
+
+    # Test that we can get SFT items without errors
+    for i in range(min(3, len(sft_dataset))):
+        item = sft_dataset[i]
+        assert "input_ids" in item
+        assert "target_ids" in item
+        assert "loss_mask" in item
+        assert item["input_ids"].shape == item["target_ids"].shape
+        assert item["loss_mask"].shape[0] == 64
+
+
+def test_dataset_with_custom_stride(base_test_env):
+    """Test dataset with custom stride parameter"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create test data
+    seq_length = 200
+    dummy_data = {
+        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+    }
+
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "stride_test_data")
+
+    # Test with custom stride
+    custom_stride = 32
+    dataset = DatasetLoader.load(
+        train_type="seq",
+        load_path=dataset_path,
+        window_size=64,
+        stride=custom_stride
+    )
+
+    assert dataset is not None
+    assert len(dataset) > 0
+
+    # With stride 32 and window 64 on length-200 data, we should get more
+    # samples than with the default stride (which equals the window size)
+    default_stride_dataset = DatasetLoader.load(
+        train_type="seq",
+        load_path=dataset_path,
+        window_size=64,
+    )
+
+    assert len(dataset) > len(default_stride_dataset)
+
+
+def test_multi_segment_fetcher(base_test_env):
+    """Test MultiSegmentFetcher functionality directly"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create test data with multiple segments
+    seq_length = 100
+    dummy_data = {
+        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+        "mask": torch.ones(seq_length, dtype=torch.bool)
+    }
+
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
+
+    # Load the memory-mapped files directly
+    multi_segments, _ = load_mmap_files(dataset_path)
+
+    # Create the fetcher
+    fetcher = MultiSegmentFetcher(multi_segments)
+
+    # Test fetching a single key
+    sequence_data = fetcher.key_fetch(0, 10, "sequence")
+    assert sequence_data is not None
+    assert len(sequence_data) == 10
+
+    # Test fetching multiple keys
+    multi_data = fetcher.key_fetch(0, 10, ["sequence", "mask"])
+    assert "sequence" in multi_data
+    assert "mask" in multi_data
+    assert len(multi_data["sequence"]) == 10
+    assert len(multi_data["mask"]) == 10
+
+    # Test fetching all keys
+    all_data = fetcher.fetch_data(0, 10)
+    assert "sequence" in all_data
+    assert "mask" in all_data
\ No newline at end of file
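A note on the stride assertion in test_dataset_with_custom_stride: under the conventional sliding-window count (the library's exact boundary handling may differ), a window of 64 over 200 tokens yields more samples at stride 32 than at the default stride equal to the window size:

    def num_windows(n_tokens: int, window: int, stride: int) -> int:
        # floor((N - window) / stride) + 1, the usual sliding-window count
        return (n_tokens - window) // stride + 1

    assert num_windows(200, 64, 32) == 5   # overlapping windows
    assert num_windows(200, 64, 64) == 3   # non-overlapping default

which is why the test expects len(dataset) > len(default_stride_dataset).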