From 5d3799b7159738abcc8c68be7d0f989ae562a0b8 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Thu, 30 Oct 2025 16:32:25 +0800
Subject: [PATCH] refactor(data): rename variables
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/data/data_util.py | 46 ++++++++++++++++++++--------------------
 train.py                 | 16 +++++++-------
 2 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/khaosz/data/data_util.py b/khaosz/data/data_util.py
index 1053d48..d3608ad 100644
--- a/khaosz/data/data_util.py
+++ b/khaosz/data/data_util.py
@@ -81,11 +81,11 @@ class MutiSegmentFetcher:
 
 
 class BaseDataset(Dataset, ABC):
-    def __init__(self, chunk_size: int, step_size: int):
+    def __init__(self, window_size: int, stride: int):
         super().__init__()
         self.segments: MutiSeg = {}
-        self.chunk_size = chunk_size
-        self.step_size = step_size
+        self.window_size = window_size
+        self.stride = stride
         self.total_samples = None
 
     def save(self, save_path: str):
@@ -106,8 +106,8 @@ class BaseDataset(Dataset, ABC):
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
-        begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1)
-        end_idx = begin_idx + self.chunk_size
+        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
+        end_idx = begin_idx + self.window_size
 
         return begin_idx, end_idx
 
@@ -117,14 +117,14 @@ class BaseDataset(Dataset, ABC):
 
     def __len__(self) -> int:
         assert self.total_samples is not None
-        if self.total_samples <= self.chunk_size:
+        if self.total_samples <= self.window_size:
             return 0
-        return self.total_samples // self.step_size + 1
+        return self.total_samples // self.stride + 1
 
 
 class SeqDataset(BaseDataset):
-    def __init__(self, chunk_size: int, step_size: int):
-        super().__init__(chunk_size, step_size)
+    def __init__(self, window_size: int, stride: int):
+        super().__init__(window_size, stride)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
@@ -141,8 +141,8 @@ class SftDataset(BaseDataset):
-    def __init__(self, chunk_size: int, step_size: int):
-        super().__init__(chunk_size, step_size)
+    def __init__(self, window_size: int, stride: int):
+        super().__init__(window_size, stride)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@@ -159,8 +159,8 @@ class DpoDataset(BaseDataset):
-    def __init__(self, chunk_size: int, step_size: int):
-        super().__init__(chunk_size, step_size)
+    def __init__(self, window_size: int, stride: int):
+        super().__init__(window_size, stride)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@@ -178,8 +178,8 @@ class PpoDataset(BaseDataset):
-    def __init__(self, chunk_size: int, step_size: int):
-        super().__init__(chunk_size, step_size)
+    def __init__(self, window_size: int, stride: int):
+        super().__init__(window_size, stride)
         self.fetcher = MutiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@@ -201,19 +201,19 @@ class DatasetLoader:
     def load(
         train_type: Literal["seq", "sft", "dpo"],
        load_path: Union[str, List[str]],
-        max_len: int,
-        step_size: Optional[int] = None,
+        window_size: int,
+        stride: Optional[int] = None,
         **kwargs
     ) -> BaseDataset:
-        if step_size is None:
-            step_size = max_len
+        if stride is None:
+            stride = window_size
 
         dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
-            "seq": lambda max_len: SeqDataset(max_len, step_size),
-            "sft": lambda max_len: SftDataset(max_len, step_size),
-            "dpo": lambda max_len: DpoDataset(max_len, step_size),
+            "seq": lambda window_size: SeqDataset(window_size, stride),
+            "sft": lambda window_size: SftDataset(window_size, stride),
+            "dpo": lambda window_size: DpoDataset(window_size, stride),
         }
 
-        dataset = dataset_router[train_type](max_len)
+        dataset = dataset_router[train_type](window_size)
         dataset.load(load_path)
 
         return dataset
diff --git a/train.py b/train.py
index 6c337b0..c498d37 100644
--- a/train.py
+++ b/train.py
@@ -36,8 +36,8 @@ def train(
     max_grad_norm: float,
     embdeding_lr_rate: int,
     random_seed: int,
-    max_len: int,
-    step_size: int,
+    window_size: int,
+    stride: int,
     resume_from_checkpoint: bool
 ):
     assert train_type in ["seq", "sft", "dpo"]
@@ -49,8 +49,8 @@ def train(
     if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
         checkpoint = parameter
 
-    if max_len is None:
-        max_len = parameter.config.m_len
+    if window_size is None:
+        window_size = parameter.config.m_len
 
     model = parameter.model
     device = torch.device("cuda")
@@ -74,8 +74,8 @@ def train(
     dataset = DatasetLoader.load(
         train_type=train_type,
         load_path=cache_files,
-        max_len=max_len,
-        step_size=step_size,
+        window_size=window_size,
+        stride=stride,
         **kwargs
     )
 
@@ -140,8 +140,8 @@ if __name__ == "__main__":
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
 
     # other configs
-    parser.add_argument("--max_len", type=int, default=None, help="the max length of the input sequence.")
-    parser.add_argument("--step_size", type=int, default=None, help="the step size of the input sequence.")
+    parser.add_argument("--window_size", type=int, default=None, help="Window size (max length) of each input sequence.")
+    parser.add_argument("--stride", type=int, default=None, help="Stride between consecutive windows; defaults to window_size.")
     parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
     parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
     parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
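
Usage sketch for reviewers: the snippet below shows how the renamed loader API is called after this patch, assuming the module is importable as khaosz.data.data_util; the cache path and sizes are placeholder values, not files from the repository.

    from khaosz.data.data_util import DatasetLoader

    # window_size replaces max_len; stride replaces step_size.
    sft_dataset = DatasetLoader.load(
        train_type="sft",
        load_path="dataset_cache.pkl",  # placeholder path
        window_size=1024,               # length of each training window
        stride=512,                     # offset between consecutive windows
    )

    # Omitting stride keeps the previous default: it falls back to
    # window_size, producing non-overlapping windows.
    seq_dataset = DatasetLoader.load(
        train_type="seq",
        load_path="dataset_cache.pkl",
        window_size=1024,
    )

The same rename applies on the command line: train.py now accepts --window_size and --stride in place of --max_len and --step_size.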