diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index 5219e16..4a18373 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -249,10 +249,9 @@ class DatasetLoader:
     @staticmethod
     def load(
         train_type: Literal["seq", "sft", "dpo"],
-        load_path: Union[str, List[str]],
+        load_path: str,
         window_size: int,
         stride: Optional[int] = None,
-        **kwargs
     ) -> BaseDataset:
         if stride is None:
             stride = window_size
diff --git a/tools/train.py b/tools/train.py
index b87a9c5..d4bb15d 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -8,14 +8,39 @@
 from khaosz.trainer import Trainer, StrategyFactory
 from khaosz.data import DatasetLoader
 
-PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Train the Transformer model.")
+
+    parser.add_argument("--train_type", choices=["seq", "sft", "dpo"], help="Train type.")
+    parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
+    parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
+
+    parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
+    parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of warmup steps for the learning rate schedule.")
+    parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
+    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
+    parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta1 for the AdamW optimizer.")
+    parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta2 for the AdamW optimizer.")
+    parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for the AdamW optimizer.")
+    parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="Ratio of the embedding layers' learning rate to the max learning rate.")
+    parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
+    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
+    parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pinned memory for data loading.")
+    parser.add_argument("--window_size", type=int, default=None, help="Max length of the input sequence.")
+    parser.add_argument("--stride", type=int, default=None, help="Step size between consecutive input windows.")
+    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
+
+    parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
+    parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
+    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
+    parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
+    parser.add_argument("--resume_from_checkpoint", action="store_true", help="Resume training from a checkpoint.")
+
+    args = parser.parse_args()
 
-def get_files(root_path: str) -> list[str]:
-    paths = []
-    for root, _, files in os.walk(root_path):
-        paths.extend([os.path.join(root, file) for file in files])
-
-    return paths
+    return args
 
 def train(
     train_type: str,
@@ -36,6 +61,8 @@ def train(
     max_grad_norm: float,
     embdeding_lr_rate: int,
     random_seed: int,
+    num_workers: int,
+    pin_memory: bool,
     window_size: int,
     stride: int,
     resume_from_checkpoint: bool
@@ -55,7 +82,6 @@ def train(
     model = parameter.model
     device = torch.device("cuda")
     model = model.to(device=device, dtype=torch.bfloat16)
-    cache_files = get_files(data_root_path)
 
     kwargs = {
         "dpo_beta": dpo_beta,
@@ -73,7 +99,7 @@ def train(
     dataset = DatasetLoader.load(
         train_type=train_type,
-        load_path=cache_files,
+        load_path=data_root_path,
         window_size=window_size,
         stride=stride,
         **kwargs
@@ -103,8 +129,8 @@ def train(
         accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
-        num_workers=4,
-        pin_memory=True
+        num_workers=num_workers,
+        pin_memory=pin_memory
     )
 
     schedule_config = CosineScheduleConfig(
@@ -121,32 +147,5 @@
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Train the Transformer model.")
-    # train args
-    parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
-    parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
-    parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
-    parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
-    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
-    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
-    parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
-    parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
-    parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
-    parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
-    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
-    parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
-    parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
-    parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
-    parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
-
-    # other configs
-    parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
-    parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
-    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
-    parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
-    parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
-    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
-
-    args = parser.parse_args()
-
+    args = parse_args()
     train(**vars(args))
\ No newline at end of file
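
For reference, a sketch of how the reworked entry point could be invoked once parse_args() replaces the inline parser; the paths and flag values below are illustrative, not taken from the PR:

    # train an SFT model from a dataset directory; file discovery now happens
    # inside DatasetLoader.load(), so only the root path is passed
    python tools/train.py \
        --train_type sft \
        --data_root_path ./data \
        --param_path ./params \
        --n_epoch 3 \
        --batch_size 8 \
        --window_size 1024 \
        --stride 512 \
        --num_workers 8

Two behavioral notes on the new flags: pinned memory is now on by default and is disabled with --no_pin_memory (action="store_false" into dest="pin_memory"), and --resume_from_checkpoint becomes a proper store_true flag, replacing the old type=bool argument, which argparse would have treated as True for any non-empty string (including "False").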