refactor(data): 修改变量命名方式

This commit is contained in:
ViperEkura 2025-10-30 16:32:25 +08:00
parent 6a3135f401
commit 5d3799b715
2 changed files with 31 additions and 31 deletions

View File

@ -81,11 +81,11 @@ class MutiSegmentFetcher:
class BaseDataset(Dataset, ABC):
def __init__(self, chunk_size: int, step_size: int):
def __init__(self, window_size: int, stride: int):
super().__init__()
self.segments: MutiSeg = {}
self.chunk_size = chunk_size
self.step_size = step_size
self.window_size = window_size
self.stride = stride
self.total_samples = None
def save(self, save_path: str):
@ -106,8 +106,8 @@ class BaseDataset(Dataset, ABC):
self.fetcher = MutiSegmentFetcher(self.segments)
def get_index(self, index: int) -> int:
begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
end_idx = begin_idx + self.window_size
return begin_idx, end_idx
@ -117,14 +117,14 @@ class BaseDataset(Dataset, ABC):
def __len__(self) -> int:
assert self.total_samples is not None
if self.total_samples <= self.chunk_size:
if self.total_samples <= self.window_size:
return 0
return self.total_samples // self.step_size + 1
return self.total_samples // self.stride + 1
class SeqDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
@ -141,8 +141,8 @@ class SeqDataset(BaseDataset):
class SftDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -159,8 +159,8 @@ class SftDataset(BaseDataset):
class DpoDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -178,8 +178,8 @@ class DpoDataset(BaseDataset):
class PpoDataset(BaseDataset):
def __init__(self, chunk_size: int, step_size: int):
super().__init__(chunk_size, step_size)
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -201,19 +201,19 @@ class DatasetLoader:
def load(
train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]],
max_len: int,
step_size: Optional[int] = None,
window_size: int,
stride: Optional[int] = None,
**kwargs
) -> BaseDataset:
if step_size is None:
step_size = max_len
if stride is None:
stride = window_size
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
"seq": lambda max_len: SeqDataset(max_len, step_size),
"sft": lambda max_len: SftDataset(max_len, step_size),
"dpo": lambda max_len: DpoDataset(max_len, step_size),
"seq": lambda window_size: SeqDataset(window_size, stride),
"sft": lambda window_size: SftDataset(window_size, stride),
"dpo": lambda window_size: DpoDataset(window_size, stride),
}
dataset = dataset_router[train_type](max_len)
dataset = dataset_router[train_type](window_size)
dataset.load(load_path)
return dataset

View File

@ -36,8 +36,8 @@ def train(
max_grad_norm: float,
embdeding_lr_rate: int,
random_seed: int,
max_len: int,
step_size: int,
window_size: int,
stride: int,
resume_from_checkpoint: bool
):
assert train_type in ["seq", "sft", "dpo"]
@ -49,8 +49,8 @@ def train(
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
checkpoint = parameter
if max_len is None:
max_len = parameter.config.m_len
if window_size is None:
window_size = parameter.config.m_len
model = parameter.model
device = torch.device("cuda")
@ -74,8 +74,8 @@ def train(
dataset = DatasetLoader.load(
train_type=train_type,
load_path=cache_files,
max_len=max_len,
step_size=step_size,
window_size=window_size,
stride=stride,
**kwargs
)
@ -140,8 +140,8 @@ if __name__ == "__main__":
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
# other configs
parser.add_argument("--max_len", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--step_size", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")