refactor(data): 修改变量命名方式
This commit is contained in:
parent
6a3135f401
commit
5d3799b715
|
|
@ -81,11 +81,11 @@ class MutiSegmentFetcher:
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
class BaseDataset(Dataset, ABC):
|
||||||
def __init__(self, chunk_size: int, step_size: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.segments: MutiSeg = {}
|
self.segments: MutiSeg = {}
|
||||||
self.chunk_size = chunk_size
|
self.window_size = window_size
|
||||||
self.step_size = step_size
|
self.stride = stride
|
||||||
self.total_samples = None
|
self.total_samples = None
|
||||||
|
|
||||||
def save(self, save_path: str):
|
def save(self, save_path: str):
|
||||||
|
|
@ -106,8 +106,8 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def get_index(self, index: int) -> int:
|
def get_index(self, index: int) -> int:
|
||||||
begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1)
|
begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
|
||||||
end_idx = begin_idx + self.chunk_size
|
end_idx = begin_idx + self.window_size
|
||||||
|
|
||||||
return begin_idx, end_idx
|
return begin_idx, end_idx
|
||||||
|
|
||||||
|
|
@ -117,14 +117,14 @@ class BaseDataset(Dataset, ABC):
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
assert self.total_samples is not None
|
assert self.total_samples is not None
|
||||||
if self.total_samples <= self.chunk_size:
|
if self.total_samples <= self.window_size:
|
||||||
return 0
|
return 0
|
||||||
return self.total_samples // self.step_size + 1
|
return self.total_samples // self.stride + 1
|
||||||
|
|
||||||
|
|
||||||
class SeqDataset(BaseDataset):
|
class SeqDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size: int, step_size: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(chunk_size, step_size)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
|
|
@ -141,8 +141,8 @@ class SeqDataset(BaseDataset):
|
||||||
|
|
||||||
|
|
||||||
class SftDataset(BaseDataset):
|
class SftDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size: int, step_size: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(chunk_size, step_size)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
|
@ -159,8 +159,8 @@ class SftDataset(BaseDataset):
|
||||||
|
|
||||||
|
|
||||||
class DpoDataset(BaseDataset):
|
class DpoDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size: int, step_size: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(chunk_size, step_size)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
|
@ -178,8 +178,8 @@ class DpoDataset(BaseDataset):
|
||||||
|
|
||||||
|
|
||||||
class PpoDataset(BaseDataset):
|
class PpoDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size: int, step_size: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(chunk_size, step_size)
|
super().__init__(window_size, stride)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
|
@ -201,19 +201,19 @@ class DatasetLoader:
|
||||||
def load(
|
def load(
|
||||||
train_type: Literal["seq", "sft", "dpo"],
|
train_type: Literal["seq", "sft", "dpo"],
|
||||||
load_path: Union[str, List[str]],
|
load_path: Union[str, List[str]],
|
||||||
max_len: int,
|
window_size: int,
|
||||||
step_size: Optional[int] = None,
|
stride: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> BaseDataset:
|
) -> BaseDataset:
|
||||||
if step_size is None:
|
if stride is None:
|
||||||
step_size = max_len
|
stride = window_size
|
||||||
|
|
||||||
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
||||||
"seq": lambda max_len: SeqDataset(max_len, step_size),
|
"seq": lambda window_size: SeqDataset(window_size, stride),
|
||||||
"sft": lambda max_len: SftDataset(max_len, step_size),
|
"sft": lambda window_size: SftDataset(window_size, stride),
|
||||||
"dpo": lambda max_len: DpoDataset(max_len, step_size),
|
"dpo": lambda window_size: DpoDataset(window_size, stride),
|
||||||
}
|
}
|
||||||
dataset = dataset_router[train_type](max_len)
|
dataset = dataset_router[train_type](window_size)
|
||||||
dataset.load(load_path)
|
dataset.load(load_path)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
|
||||||
16
train.py
16
train.py
|
|
@ -36,8 +36,8 @@ def train(
|
||||||
max_grad_norm: float,
|
max_grad_norm: float,
|
||||||
embdeding_lr_rate: int,
|
embdeding_lr_rate: int,
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
max_len: int,
|
window_size: int,
|
||||||
step_size: int,
|
stride: int,
|
||||||
resume_from_checkpoint: bool
|
resume_from_checkpoint: bool
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
|
|
@ -49,8 +49,8 @@ def train(
|
||||||
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
||||||
checkpoint = parameter
|
checkpoint = parameter
|
||||||
|
|
||||||
if max_len is None:
|
if window_size is None:
|
||||||
max_len = parameter.config.m_len
|
window_size = parameter.config.m_len
|
||||||
|
|
||||||
model = parameter.model
|
model = parameter.model
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
|
@ -74,8 +74,8 @@ def train(
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=cache_files,
|
load_path=cache_files,
|
||||||
max_len=max_len,
|
window_size=window_size,
|
||||||
step_size=step_size,
|
stride=stride,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -140,8 +140,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||||
|
|
||||||
# other configs
|
# other configs
|
||||||
parser.add_argument("--max_len", type=int, default=None, help="the max length of the input sequence.")
|
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
||||||
parser.add_argument("--step_size", type=int, default=None, help="the step size of the input sequence.")
|
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
||||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||||
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue