feat(tools/train): optimize training parameter passing
parent 3ee84b31a0
commit db53cc5001
@@ -249,10 +249,9 @@ class DatasetLoader:
     @staticmethod
     def load(
         train_type: Literal["seq", "sft", "dpo"],
-        load_path: Union[str, List[str]],
+        load_path: str,
         window_size: int,
         stride: Optional[int] = None,
         **kwargs
     ) -> BaseDataset:
         if stride is None:
             stride = window_size
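`DatasetLoader.load` now accepts a single root path instead of a string-or-list union, so call sites no longer pre-collect file lists. A minimal usage sketch (the directory path and hyperparameter values are illustrative, and the loader discovering files under the directory itself is an assumption based on this commit):

```python
from khaosz.data import DatasetLoader

# Hypothetical call site after this commit: pass the dataset root directory.
dataset = DatasetLoader.load(
    train_type="sft",
    load_path="data/sft_corpus",  # a single directory path, not a file list
    window_size=1024,
    stride=None,  # None falls back to window_size, i.e. non-overlapping windows
)
```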
@@ -8,14 +8,39 @@ from khaosz.trainer import Trainer, StrategyFactory
 from khaosz.data import DatasetLoader
 
 
 PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Train the Transformer model.")
+
+    parser.add_argument("--train_type", choices=["seq", "sft", "dpo"], help="Train type.")
+    parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
+    parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
+
+    parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
+    parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of warmup steps for the learning rate schedule.")
+    parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
+    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
+    parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta1 value for the AdamW optimizer.")
+    parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta2 value for the AdamW optimizer.")
+    parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for the AdamW optimizer.")
+    parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="Ratio of the embedding-layer learning rate to the max learning rate.")
+    parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
+    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
+    parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory.")
+    parser.add_argument("--window_size", type=int, default=None, help="Max length of the input sequence.")
+    parser.add_argument("--stride", type=int, default=None, help="Step size of the sliding window over the input sequence.")
+    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
+
+    parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
+    parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
+    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
+    parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
+    parser.add_argument("--resume_from_checkpoint", action="store_true", help="Resume training from a checkpoint.")
+
+    args = parser.parse_args()
+
-def get_files(root_path: str) -> list[str]:
-    paths = []
-    for root, _, files in os.walk(root_path):
-        paths.extend([os.path.join(root, file) for file in files])
-
-    return paths
+    return args
 
 def train(
     train_type: str,
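Note the `--no_pin_memory` flag: `action="store_false"` with `dest="pin_memory"` makes `args.pin_memory` default to `True` and lets the flag switch it off. A self-contained sketch of the idiom:

```python
import argparse

parser = argparse.ArgumentParser()
# store_false => the destination defaults to True; passing the flag stores False.
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory",
                    help="Disable pin memory.")

assert parser.parse_args([]).pin_memory is True
assert parser.parse_args(["--no_pin_memory"]).pin_memory is False
```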
@@ -36,6 +61,8 @@ def train(
     max_grad_norm: float,
     embdeding_lr_rate: int,
     random_seed: int,
+    num_workers: int,
+    pin_memory: bool,
     window_size: int,
     stride: int,
     resume_from_checkpoint: bool
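`num_workers` and `pin_memory` are the standard `torch.utils.data.DataLoader` options; this change threads them from the CLI into `train` instead of hardcoding them (see the hunk at old line 103 below). A runnable sketch of what the two options control — the trainer's internal wiring is not shown in this diff, so this stand-in dataset is an assumption:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    dataset = TensorDataset(torch.arange(32).reshape(8, 4))
    loader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=4,    # parallel worker processes for loading batches
        pin_memory=True,  # page-locked host memory speeds host-to-device copies
    )
    for (batch,) in loader:
        pass  # each batch has shape (2, 4)
```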
@@ -55,7 +82,6 @@ def train(
     model = parameter.model
     device = torch.device("cuda")
     model = model.to(device=device, dtype=torch.bfloat16)
-    cache_files = get_files(data_root_path)
 
     kwargs = {
         "dpo_beta": dpo_beta,
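With `get_files` gone, `train` forwards the raw `data_root_path`, while type-specific options such as `dpo_beta` keep travelling through the `kwargs` dict into `DatasetLoader.load`. A hypothetical sketch of that dispatch pattern — only the `train_type` choices come from this diff; the branch internals are assumptions:

```python
from typing import Any

def load(train_type: str, load_path: str, window_size: int, **kwargs: Any):
    # Only the DPO path consumes dpo_beta; "seq" and "sft" ignore it,
    # which is why it rides in **kwargs rather than a named parameter.
    if train_type == "dpo":
        beta = kwargs.get("dpo_beta", 0.1)
        return ("dpo", load_path, window_size, beta)
    return (train_type, load_path, window_size)
```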
@@ -73,7 +99,7 @@ def train(
 
     dataset = DatasetLoader.load(
         train_type=train_type,
-        load_path=cache_files,
+        load_path=data_root_path,
        window_size=window_size,
         stride=stride,
         **kwargs
@@ -103,8 +129,8 @@ def train(
         accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
-        num_workers=4,
-        pin_memory=True
+        num_workers=num_workers,
+        pin_memory=pin_memory
     )
 
     schedule_config = CosineScheduleConfig(
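The `CosineScheduleConfig` fields are cut off in this hunk, but the flags it is built from (`warmup_steps`, `max_lr`) usually parameterize a warmup-then-cosine schedule. A generic sketch of that shape, not the khaosz implementation:

```python
import math

def lr_at(step: int, max_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to max_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```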
@@ -121,32 +147,5 @@ def train(
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Train the Transformer model.")
-    # train args
-    parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
-    parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
-    parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
-    parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
-    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
-    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
-    parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
-    parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
-    parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
-    parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
-    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
-    parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
-    parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
-    parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
-    parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
-
-    # other configs
-    parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
-    parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
-    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
-    parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
-    parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
-    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
-
-    args = parser.parse_args()
-
+    args = parse_args()
     train(**vars(args))
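The `__main__` block now collapses to `parse_args()` plus `train(**vars(args))`, which works only because every argparse destination matches a parameter of `train` one-to-one. A minimal illustration of the pattern:

```python
import argparse

def train(n_epoch: int, batch_size: int) -> None:
    print(f"{n_epoch=} {batch_size=}")

parser = argparse.ArgumentParser()
parser.add_argument("--n_epoch", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1)
args = parser.parse_args(["--n_epoch", "3"])

# vars(args) == {"n_epoch": 3, "batch_size": 1}; each key must name a
# parameter of train(), so the CLI flags must mirror the signature exactly.
train(**vars(args))
```

This also explains two fixes in the new `parse_args`: the old `type=bool` flag was replaced with `action="store_true"` because `bool("False")` is `True`, so `--resume_from_checkpoint False` silently enabled resuming; and `type=tuple` on `--adamw_betas` would split its string argument into individual characters, hence the separate `--adamw_beta1`/`--adamw_beta2` floats.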