refactor(data): 修改变量命名方式

This commit is contained in:
ViperEkura 2025-10-30 16:32:25 +08:00
parent 6a3135f401
commit 5d3799b715
2 changed files with 31 additions and 31 deletions

View File

@ -81,11 +81,11 @@ class MutiSegmentFetcher:
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
def __init__(self, chunk_size: int, step_size: int): def __init__(self, window_size: int, stride: int):
super().__init__() super().__init__()
self.segments: MutiSeg = {} self.segments: MutiSeg = {}
self.chunk_size = chunk_size self.window_size = window_size
self.step_size = step_size self.stride = stride
self.total_samples = None self.total_samples = None
def save(self, save_path: str): def save(self, save_path: str):
@ -106,8 +106,8 @@ class BaseDataset(Dataset, ABC):
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def get_index(self, index: int) -> int: def get_index(self, index: int) -> int:
begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1) begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
end_idx = begin_idx + self.chunk_size end_idx = begin_idx + self.window_size
return begin_idx, end_idx return begin_idx, end_idx
@ -117,14 +117,14 @@ class BaseDataset(Dataset, ABC):
def __len__(self) -> int: def __len__(self) -> int:
assert self.total_samples is not None assert self.total_samples is not None
if self.total_samples <= self.chunk_size: if self.total_samples <= self.window_size:
return 0 return 0
return self.total_samples // self.step_size + 1 return self.total_samples // self.stride + 1
class SeqDataset(BaseDataset): class SeqDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int): def __init__(self, window_size: int, stride: int):
super().__init__(chunk_size, step_size) super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
@ -141,8 +141,8 @@ class SeqDataset(BaseDataset):
class SftDataset(BaseDataset): class SftDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int): def __init__(self, window_size: int, stride: int):
super().__init__(chunk_size, step_size) super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -159,8 +159,8 @@ class SftDataset(BaseDataset):
class DpoDataset(BaseDataset): class DpoDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int): def __init__(self, window_size: int, stride: int):
super().__init__(chunk_size, step_size) super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -178,8 +178,8 @@ class DpoDataset(BaseDataset):
class PpoDataset(BaseDataset): class PpoDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int): def __init__(self, window_size: int, stride: int):
super().__init__(chunk_size, step_size) super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -201,19 +201,19 @@ class DatasetLoader:
def load( def load(
train_type: Literal["seq", "sft", "dpo"], train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]], load_path: Union[str, List[str]],
max_len: int, window_size: int,
step_size: Optional[int] = None, stride: Optional[int] = None,
**kwargs **kwargs
) -> BaseDataset: ) -> BaseDataset:
if step_size is None: if stride is None:
step_size = max_len stride = window_size
dataset_router: Dict[str, Callable[[int], BaseDataset]] = { dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
"seq": lambda max_len: SeqDataset(max_len, step_size), "seq": lambda window_size: SeqDataset(window_size, stride),
"sft": lambda max_len: SftDataset(max_len, step_size), "sft": lambda window_size: SftDataset(window_size, stride),
"dpo": lambda max_len: DpoDataset(max_len, step_size), "dpo": lambda window_size: DpoDataset(window_size, stride),
} }
dataset = dataset_router[train_type](max_len) dataset = dataset_router[train_type](window_size)
dataset.load(load_path) dataset.load(load_path)
return dataset return dataset

View File

@ -36,8 +36,8 @@ def train(
max_grad_norm: float, max_grad_norm: float,
embdeding_lr_rate: int, embdeding_lr_rate: int,
random_seed: int, random_seed: int,
max_len: int, window_size: int,
step_size: int, stride: int,
resume_from_checkpoint: bool resume_from_checkpoint: bool
): ):
assert train_type in ["seq", "sft", "dpo"] assert train_type in ["seq", "sft", "dpo"]
@ -49,8 +49,8 @@ def train(
if isinstance(parameter, Checkpoint) and resume_from_checkpoint: if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
checkpoint = parameter checkpoint = parameter
if max_len is None: if window_size is None:
max_len = parameter.config.m_len window_size = parameter.config.m_len
model = parameter.model model = parameter.model
device = torch.device("cuda") device = torch.device("cuda")
@ -74,8 +74,8 @@ def train(
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type=train_type, train_type=train_type,
load_path=cache_files, load_path=cache_files,
max_len=max_len, window_size=window_size,
step_size=step_size, stride=stride,
**kwargs **kwargs
) )
@ -140,8 +140,8 @@ if __name__ == "__main__":
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
# other configs # other configs
parser.add_argument("--max_len", type=int, default=None, help="the max length of the input sequence.") parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--step_size", type=int, default=None, help="the step size of the input sequence.") parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.") parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")