feat(tools/train): 优化训练参数传递

This commit is contained in:
ViperEkura 2025-11-30 13:49:24 +08:00
parent 3ee84b31a0
commit db53cc5001
2 changed files with 39 additions and 41 deletions

View File

@ -249,10 +249,9 @@ class DatasetLoader:
@staticmethod
def load(
train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]],
load_path: str,
window_size: int,
stride: Optional[int] = None,
**kwargs
) -> BaseDataset:
if stride is None:
stride = window_size

View File

@ -8,14 +8,39 @@ from khaosz.trainer import Trainer, StrategyFactory
from khaosz.data import DatasetLoader
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.")
def get_files(root_path: str) -> list[str]:
paths = []
for root, _, files in os.walk(root_path):
paths.extend([os.path.join(root, file) for file in files])
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
return paths
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_beta2", type=float, default=0.95, help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
parser.add_argument("--no_pin_memory", action="store_false", dest="pin_memory", help="Disable pin memory")
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--resume_from_checkpoint", action="store_true", help="Train from checkpoint or not.")
args = parser.parse_args()
return args
def train(
train_type: str,
@ -36,6 +61,8 @@ def train(
max_grad_norm: float,
embdeding_lr_rate: int,
random_seed: int,
num_workers: int,
pin_memory: bool,
window_size: int,
stride: int,
resume_from_checkpoint: bool
@ -55,7 +82,6 @@ def train(
model = parameter.model
device = torch.device("cuda")
model = model.to(device=device, dtype=torch.bfloat16)
cache_files = get_files(data_root_path)
kwargs = {
"dpo_beta": dpo_beta,
@ -73,7 +99,7 @@ def train(
dataset = DatasetLoader.load(
train_type=train_type,
load_path=cache_files,
load_path=data_root_path,
window_size=window_size,
stride=stride,
**kwargs
@ -103,8 +129,8 @@ def train(
accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=4,
pin_memory=True
num_workers=num_workers,
pin_memory=pin_memory
)
schedule_config = CosineScheduleConfig(
@ -121,32 +147,5 @@ def train(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train the Transformer model.")
# train args
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
# other configs
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
args = parser.parse_args()
args = parse_args()
train(**vars(args))