refactor(data): 修改变量命名方式
This commit is contained in:
parent
6a3135f401
commit
5d3799b715
|
|
@ -81,11 +81,11 @@ class MutiSegmentFetcher:
|
|||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__()
|
||||
self.segments: MutiSeg = {}
|
||||
self.chunk_size = chunk_size
|
||||
self.step_size = step_size
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.total_samples = None
|
||||
|
||||
def save(self, save_path: str):
|
||||
|
|
@ -106,8 +106,8 @@ class BaseDataset(Dataset, ABC):
|
|||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def get_index(self, index: int) -> int:
|
||||
begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
|
||||
end_idx = begin_idx + self.window_size
|
||||
|
||||
return begin_idx, end_idx
|
||||
|
||||
|
|
@ -117,14 +117,14 @@ class BaseDataset(Dataset, ABC):
|
|||
|
||||
def __len__(self) -> int:
|
||||
assert self.total_samples is not None
|
||||
if self.total_samples <= self.chunk_size:
|
||||
if self.total_samples <= self.window_size:
|
||||
return 0
|
||||
return self.total_samples // self.step_size + 1
|
||||
return self.total_samples // self.stride + 1
|
||||
|
||||
|
||||
class SeqDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
|
|
@ -141,8 +141,8 @@ class SeqDataset(BaseDataset):
|
|||
|
||||
|
||||
class SftDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
|
|
@ -159,8 +159,8 @@ class SftDataset(BaseDataset):
|
|||
|
||||
|
||||
class DpoDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
|
|
@ -178,8 +178,8 @@ class DpoDataset(BaseDataset):
|
|||
|
||||
|
||||
class PpoDataset(BaseDataset):
|
||||
def __init__(self, chunk_size: int, step_size: int):
|
||||
super().__init__(chunk_size, step_size)
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
|
|
@ -201,19 +201,19 @@ class DatasetLoader:
|
|||
def load(
|
||||
train_type: Literal["seq", "sft", "dpo"],
|
||||
load_path: Union[str, List[str]],
|
||||
max_len: int,
|
||||
step_size: Optional[int] = None,
|
||||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> BaseDataset:
|
||||
if step_size is None:
|
||||
step_size = max_len
|
||||
if stride is None:
|
||||
stride = window_size
|
||||
|
||||
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
||||
"seq": lambda max_len: SeqDataset(max_len, step_size),
|
||||
"sft": lambda max_len: SftDataset(max_len, step_size),
|
||||
"dpo": lambda max_len: DpoDataset(max_len, step_size),
|
||||
"seq": lambda window_size: SeqDataset(window_size, stride),
|
||||
"sft": lambda window_size: SftDataset(window_size, stride),
|
||||
"dpo": lambda window_size: DpoDataset(window_size, stride),
|
||||
}
|
||||
dataset = dataset_router[train_type](max_len)
|
||||
dataset = dataset_router[train_type](window_size)
|
||||
dataset.load(load_path)
|
||||
|
||||
return dataset
|
||||
|
|
|
|||
16
train.py
16
train.py
|
|
@ -36,8 +36,8 @@ def train(
|
|||
max_grad_norm: float,
|
||||
embdeding_lr_rate: int,
|
||||
random_seed: int,
|
||||
max_len: int,
|
||||
step_size: int,
|
||||
window_size: int,
|
||||
stride: int,
|
||||
resume_from_checkpoint: bool
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo"]
|
||||
|
|
@ -49,8 +49,8 @@ def train(
|
|||
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
||||
checkpoint = parameter
|
||||
|
||||
if max_len is None:
|
||||
max_len = parameter.config.m_len
|
||||
if window_size is None:
|
||||
window_size = parameter.config.m_len
|
||||
|
||||
model = parameter.model
|
||||
device = torch.device("cuda")
|
||||
|
|
@ -74,8 +74,8 @@ def train(
|
|||
dataset = DatasetLoader.load(
|
||||
train_type=train_type,
|
||||
load_path=cache_files,
|
||||
max_len=max_len,
|
||||
step_size=step_size,
|
||||
window_size=window_size,
|
||||
stride=stride,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
|
@ -140,8 +140,8 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||
|
||||
# other configs
|
||||
parser.add_argument("--max_len", type=int, default=None, help="the max length of the input sequence.")
|
||||
parser.add_argument("--step_size", type=int, default=None, help="the step size of the input sequence.")
|
||||
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
||||
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
||||
|
|
|
|||
Loading…
Reference in New Issue