feat(data): refactor dataset loading logic, fix sample count error
This commit is contained in:
parent 567c55685e
commit 3ee84b31a0
@@ -5,8 +5,7 @@ from khaosz.data.dataset import (
     SftDataset,
     PpoDataset,
     MultiSegmentFetcher,
-    DatasetLoader,
-    load_pkl_files,
+    DatasetLoader
 )
 
 from khaosz.data.tokenizer import BpeTokenizer
@@ -20,7 +19,6 @@ __all__ = [
     "PpoDataset",
     "MultiSegmentFetcher",
     "DatasetLoader",
-    "load_pkl_files",
    "BpeTokenizer",
     "ResumableDistributedSampler"
 ]
@@ -67,8 +67,6 @@ def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
     with open(file_mapper_path, "r") as f:
         metadata_list = json.load(f)
 
-    num_samples = sum(metadata["size"] for metadata in metadata_list)
-
     for metadata in metadata_list:
         file_path = os.path.join(root_path, metadata["file_name"])
         if not os.path.exists(file_path):
@@ -84,6 +82,9 @@ def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
 
         mmap_shared_group[segment_key].append(mmap_tensor)
 
+        num_samples = sum(metadata["size"] for metadata in metadata_list
+                          if segment_key == metadata["key"])
+
     return mmap_shared_group, num_samples
 
 
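This is the counting fix named in the commit title. The old code summed `metadata["size"]` over every entry in `file_mapper.json`, so a dataset whose mapper lists several keys (for example `chosen`, `rejected`, `chosen_mask`, `rejected_mask`) reported a multiple of the true sample count; the new code restricts the sum to entries whose `key` matches the segment being counted. A minimal sketch of the difference, using a hypothetical four-key mapper:

```python
# Hypothetical mapper for one preference dataset: four keys, 100 samples each.
metadata_list = [
    {"key": k, "size": 100, "file_name": f"{k}.bin", "dtype": "int64"}
    for k in ("chosen", "rejected", "chosen_mask", "rejected_mask")
]

old_count = sum(m["size"] for m in metadata_list)       # 400: every key counted
new_count = sum(m["size"] for m in metadata_list
                if m["key"] == "chosen")                # 100: one key only

assert old_count == 4 * new_count
```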
@@ -142,16 +143,15 @@ class MultiSegmentFetcher:
 
 
 class BaseDataset(Dataset, ABC):
-    def __init__(self, window_size: int, stride: int, share_memory: bool=False):
+    def __init__(self, window_size: int, stride: int):
         super().__init__()
         self.segments: MultiSeg = {}
         self.window_size = window_size
         self.stride = stride
         self.total_samples = None
 
-    def load(self, load_path: Union[str, List[str]]):
-        paths = [load_path] if isinstance(load_path, str) else load_path
-        self.segments, self.total_samples = load_mmap_files(paths)
+    def load(self, load_path: str):
+        self.segments, self.total_samples = load_mmap_files(load_path)
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
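With `load` narrowed from `Union[str, List[str]]` to a single directory path, combining several mmap datasets becomes the caller's job. A sketch under that assumption, mirroring what the reworked tests below do (the `data/part_*` directories are hypothetical):

```python
from khaosz.data.dataset import DatasetLoader

# One dataset per directory; aggregate however the training loop needs.
paths = ["data/part_0", "data/part_1"]  # hypothetical mmap dataset directories
datasets = [
    DatasetLoader.load(train_type="seq", load_path=p, window_size=64)
    for p in paths
]
total_samples = sum(len(d) for d in datasets)
```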
@@ -108,7 +108,7 @@ def base_test_env(request: pytest.FixtureRequest):
     yield {
         "device": device,
-        "test_dir": test_dir,
+        "test_dir": str(test_dir),
         "config_path": config_path,
         "transformer_config": transformer_config,
         "model": model,
@@ -1,67 +1,216 @@
 import os
+import json
 import torch
-import pickle
 import numpy as np
 
 from khaosz.trainer import *
 from khaosz.data.dataset import *
 
 
+def create_mmap_dataset(dir_path, data_dict, dataset_name):
+    """Helper function to create memory-mapped dataset for testing"""
+    dataset_dir = os.path.join(dir_path, dataset_name)
+    os.makedirs(dataset_dir, exist_ok=True)
+
+    file_mapper = []
+
+    for key, tensor in data_dict.items():
+        # Convert tensor to numpy array and save as binary file
+        np_array = tensor.numpy()
+        file_name = f"{key}.bin"
+        file_path = os.path.join(dataset_dir, file_name)
+
+        # Save as binary file
+        np_array.tofile(file_path)
+
+        # Add to file mapper
+        file_mapper.append({
+            "file_name": file_name,
+            "size": len(np_array),
+            "dtype": str(np_array.dtype),
+            "key": key
+        })
+
+    # Save file mapper
+    mapper_path = os.path.join(dataset_dir, "file_mapper.json")
+    with open(mapper_path, "w") as f:
+        json.dump(file_mapper, f, indent=2)
+
+    return dataset_dir
+
+
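The helper writes one raw `.bin` file per key plus a `file_mapper.json` recording `file_name`, `size`, `dtype`, and `key`; those fields are exactly what a reader needs to map each array back in. A sketch of the read side under that assumption (an illustration of the format only, not the library's actual `load_mmap_files`):

```python
import json
import os

import numpy as np

def read_mmap_dataset(dataset_dir):
    """Memory-map every array described by a dataset's file_mapper.json."""
    with open(os.path.join(dataset_dir, "file_mapper.json")) as f:
        mapper = json.load(f)
    arrays = {}
    for meta in mapper:
        path = os.path.join(dataset_dir, meta["file_name"])
        # dtype and element count come straight from the mapper entry.
        arrays[meta["key"]] = np.memmap(
            path, dtype=np.dtype(meta["dtype"]), mode="r", shape=(meta["size"],)
        )
    return arrays
```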
 def test_dataset_loader_random_paths(base_test_env):
     """Test dataset loader with multiple random paths"""
     test_dir = base_test_env["test_dir"]
 
-    # Create multiple pkl files with random data
+    # Create multiple mmap dataset directories with random data
     num_files = np.random.randint(2, 5)
-    pkl_paths = []
 
     for i in range(num_files):
-        pkl_path = os.path.join(test_dir, f"test_data_{i}.pkl")
-        seq_length = np.random.randint(50, 100)
+        seq_length = np.random.randint(100, 200)
         dummy_data = {
-            "sequence": torch.randint(0, 1000, (seq_length,)),
-            "chosen": torch.randint(0, 1000, (seq_length,)),
-            "rejected": torch.randint(0, 1000, (seq_length,)),
-            "chosen_mask": torch.ones(seq_length, dtype=torch.bool),
-            "rejected_mask": torch.ones(seq_length, dtype=torch.bool)
+            "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
         }
-        with open(pkl_path, "wb") as f:
-            pickle.dump(dummy_data, f)
-        pkl_paths.append(pkl_path)
-
-    # Test loading with multiple paths
-    loaded_dataset = DatasetLoader.load(
-        train_type="seq",
-        load_path=pkl_paths,
-        window_size=64,
-    )
-    assert loaded_dataset is not None
-    assert len(loaded_dataset) > 0
+        dataset_path = create_mmap_dataset(test_dir, dummy_data, f"test_data_{i}")
+
+        # Test loading with multiple paths
+        loaded_dataset = DatasetLoader.load(
+            train_type="seq",
+            load_path=dataset_path,
+            window_size=64,
+        )
+        assert loaded_dataset is not None
+        assert len(loaded_dataset) > 0
+
+        # Test that we can get items without errors
+        for i in range(min(3, len(loaded_dataset))):
+            item = loaded_dataset[i]
+            assert "input_ids" in item
+            assert "target_ids" in item
+            assert item["input_ids"].shape == item["target_ids"].shape
+            assert item["input_ids"].shape[0] == 64
 
 
 def test_dpo_strategy_with_random_data(base_test_env):
     """Test DPO strategy with randomized preference data"""
     test_dir = base_test_env["test_dir"]
 
-    # Create DPO-style data
-    pkl_path = os.path.join(test_dir, "dpo_data.pkl")
-    seq_length = np.random.randint(40, 80)
+    # Create DPO-style data with memory mapping format
+    seq_length = np.random.randint(100, 200)
 
     dummy_data = {
-        "chosen": torch.randint(0, 1000, (seq_length,)),
-        "rejected": torch.randint(0, 1000, (seq_length,)),
+        "chosen": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+        "rejected": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
         "chosen_mask": torch.ones(seq_length, dtype=torch.bool),
         "rejected_mask": torch.ones(seq_length, dtype=torch.bool)
     }
 
-    with open(pkl_path, "wb") as f:
-        pickle.dump(dummy_data, f)
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "dpo_data")
 
     # Load DPO dataset
     dpo_dataset = DatasetLoader.load(
         train_type="dpo",
-        load_path=pkl_path,
+        load_path=dataset_path,
         window_size=64,
     )
 
     assert dpo_dataset is not None
     assert hasattr(dpo_dataset, 'fetcher')
+    assert len(dpo_dataset) > 0
+
+    # Test that we can get DPO items without errors
+    for i in range(min(3, len(dpo_dataset))):
+        item = dpo_dataset[i]
+        assert "chosen" in item
+        assert "rejected" in item
+        assert "chosen_mask" in item
+        assert "rejected_mask" in item
+        assert item["chosen"].shape == item["rejected"].shape
+        assert item["chosen_mask"].shape == item["rejected_mask"].shape
+
+
+def test_sft_dataset_with_random_data(base_test_env):
+    """Test SFT dataset with random data"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create SFT-style data with memory mapping format
+    seq_length = np.random.randint(100, 200)
+
+    dummy_data = {
+        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+        "loss_mask": torch.ones(seq_length, dtype=torch.bool)
+    }
+
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "sft_data")
+
+    # Load SFT dataset
+    sft_dataset = DatasetLoader.load(
+        train_type="sft",
+        load_path=dataset_path,
+        window_size=64,
+    )
+
+    assert sft_dataset is not None
+    assert hasattr(sft_dataset, 'fetcher')
+    assert len(sft_dataset) > 0
+
+    # Test that we can get SFT items without errors
+    for i in range(min(3, len(sft_dataset))):
+        item = sft_dataset[i]
+        assert "input_ids" in item
+        assert "target_ids" in item
+        assert "loss_mask" in item
+        assert item["input_ids"].shape == item["target_ids"].shape
+        assert item["loss_mask"].shape[0] == 64
+
+
+def test_dataset_with_custom_stride(base_test_env):
+    """Test dataset with custom stride parameter"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create test data
+    seq_length = 200
+    dummy_data = {
+        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+    }
+
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "stride_test_data")
+
+    # Test with custom stride
+    custom_stride = 32
+    dataset = DatasetLoader.load(
+        train_type="seq",
+        load_path=dataset_path,
+        window_size=64,
+        stride=custom_stride
+    )
+
+    assert dataset is not None
+    assert len(dataset) > 0
+
+    # With stride 32 and window 64 on 200 length data, we should get more samples
+    # than with default stride (which equals window size)
+    default_stride_dataset = DatasetLoader.load(
+        train_type="seq",
+        load_path=dataset_path,
+        window_size=64,
+    )
+
+    assert len(dataset) > len(default_stride_dataset)
+
+
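The final assertion rests on the usual sliding-window count, assuming the dataset enumerates windows as `(seq_length - window_size) // stride + 1`: with `seq_length=200` and `window_size=64`, a stride of 32 gives `136 // 32 + 1 = 5` windows while the default stride of 64 gives `136 // 64 + 1 = 3`, so the strided dataset is strictly longer.

```python
def num_windows(seq_length, window_size, stride):
    # Standard sliding-window count (an assumption, not taken from the library).
    return (seq_length - window_size) // stride + 1

assert num_windows(200, 64, 32) == 5  # custom stride
assert num_windows(200, 64, 64) == 3  # default stride == window size
```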
+def test_multi_segment_fetcher(base_test_env):
+    """Test MultiSegmentFetcher functionality directly"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create test data with multiple segments
+    seq_length = 100
+    dummy_data = {
+        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+        "mask": torch.ones(seq_length, dtype=torch.bool)
+    }
+
+    dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
+
+    # Load the memory mapped files directly
+    multi_segments, _ = load_mmap_files(dataset_path)
+
+    # Create fetcher
+    fetcher = MultiSegmentFetcher(multi_segments)
+
+    # Test fetching single key
+    sequence_data = fetcher.key_fetch(0, 10, "sequence")
+    assert sequence_data is not None
+    assert len(sequence_data) == 10
+
+    # Test fetching multiple keys
+    multi_data = fetcher.key_fetch(0, 10, ["sequence", "mask"])
+    assert "sequence" in multi_data
+    assert "mask" in multi_data
+    assert len(multi_data["sequence"]) == 10
+    assert len(multi_data["mask"]) == 10
+
+    # Test fetching all keys
+    all_data = fetcher.fetch_data(0, 10)
+    assert "sequence" in all_data
+    assert "mask" in all_data