diff --git a/README.md b/README.md
index 0d5fa89..42744f3 100644
--- a/README.md
+++ b/README.md
@@ -54,8 +54,8 @@ python train.py \
 --n_epoch=5 \
 --batch_size=8 \
 --max_lr=2e-4 \
---n_iter_ckpt=10000 \
---ckpt_dir=checkpoints
+--checkpoint_interval=10000 \
+--checkpoint_dir=checkpoints
 ```

 **Parameters Explanation:**
@@ -64,11 +64,11 @@ python train.py \
 - `--param_path`: Path to the model training parameters
 - `--n_epoch`: Total number of training epochs
 - `--batch_size`: Batch size
-- `--n_iter_step`: Number of batches per training step
-- `--warning_step`: Number of warmup steps
+- `--accumulation_steps`: Number of batches accumulated per optimizer step (gradient accumulation)
+- `--warmup_steps`: Number of warmup steps
 - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
-- `--n_iter_ckpt`: Checkpoint saving interval
-- `--ckpt_dir`: Directory to save checkpoints
+- `--checkpoint_interval`: Checkpoint saving interval
+- `--checkpoint_dir`: Directory to save checkpoints
 - `--resume_dir`: Resume training from the specified path

 Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
@@ -214,8 +214,8 @@ python train.py \
 --n_epoch=5 \
 --batch_size=8 \
 --max_lr=2e-4 \
---n_iter_ckpt=10000 \
---ckpt_dir=checkpoints
+--checkpoint_interval=10000 \
+--checkpoint_dir=checkpoints
 ```

 **Parameters Explanation:**
@@ -224,11 +224,11 @@ python train.py \
 - `--param_path`: Path to the model training parameters
 - `--n_epoch`: Total number of training epochs
 - `--batch_size`: Batch size
-- `--n_iter_step`: Number of batches per training step
-- `--warning_step`: Number of warmup steps
+- `--accumulation_steps`: Number of batches accumulated per optimizer step (gradient accumulation)
+- `--warmup_steps`: Number of warmup steps
 - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
-- `--n_iter_ckpt`: Checkpoint saving interval
-- `--ckpt_dir`: Directory to save checkpoints
+- `--checkpoint_interval`: Checkpoint saving interval
+- `--checkpoint_dir`: Directory to save checkpoints
 - `--resume_dir`: Resume training from the specified path

 Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 696bf83..eadaf1f 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -190,7 +190,7 @@ class TrainConfig:
         default=None,
         metadata={"help": "Optimizer for training."}
     )
-    ckpt_dir: str = field(
+    checkpoint_dir: str = field(
         default="./checkpoint",
         metadata={"help": "Checkpoint directory."}
     )
@@ -202,11 +202,11 @@ class TrainConfig:
         default=4,
         metadata={"help": "Batch size for training."}
     )
-    n_iter_ckpt: int = field(
+    checkpoint_interval: int = field(
         default=5000,
         metadata={"help": "Number of iterations between checkpoints."}
     )
-    n_iter_step: int = field(
+    accumulation_steps: int = field(
         default=1,
         metadata={"help": "Number of iterations between steps."}
     )
@@ -256,7 +256,7 @@ class ScheduleConfig(ABC):

 @dataclass
 class CosineScheduleConfig(ScheduleConfig):
-    total_steps: int = field(  # more accurate naming
+    total_steps: int = field(
         default=None,
         metadata={"help": "Total training steps for cosine schedule."}
     )
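For reviewers skimming the renames: `warmup_steps`, `total_steps`, and `max_lr` are the knobs behind the "warmup + cosine decay" the README mentions. A minimal sketch of that schedule shape, assuming linear warmup and decay to zero (the function name `warmup_cosine_lr` and the zero floor are illustrative assumptions, not the repo's actual `CosineScheduleConfig` logic):

```python
import math

def warmup_cosine_lr(step: int, max_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Illustrative warmup + cosine decay: ramp up to max_lr, then cosine down to 0."""
    if step < warmup_steps:
        # Linear warmup from 0 to max_lr over the first warmup_steps iterations.
        return max_lr * step / max(1, warmup_steps)
    # Cosine decay from max_lr toward 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
```

With the README's `--max_lr=2e-4` and the default `--warmup_steps=1000`, this ramps linearly for the first 1000 iterations and then follows a half cosine down.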
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 2af62e3..7750f54 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -32,7 +32,7 @@ class Trainer:
         train_config: TrainConfig
     ):
         current_iter = len(loss_list)
-        save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}")
+        save_path = os.path.join(train_config.checkpoint_dir, f"iter_{current_iter}")
         self.checkpoint.loss_list = loss_list
         self.checkpoint.optim_state = train_config.optimizer.state_dict()
         self.checkpoint.save(save_path)
@@ -93,7 +93,7 @@ class Trainer:
             #backward
             loss.backward()
             #step
-            if current_iter % train_config.n_iter_step == 0:
+            if current_iter % train_config.accumulation_steps == 0:
                 clip_grad_norm_(
                     self.checkpoint.model.parameters(),
                     train_config.max_grad_norm
@@ -108,7 +108,7 @@ class Trainer:
                 "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
             })
             #save checkpoint
-            if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
+            if current_iter - last_ckpt_iter >= train_config.checkpoint_interval:
                 self.save_checkpoint(loss_list, train_config)
                 last_ckpt_iter = current_iter
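The `n_iter_step` → `accumulation_steps` rename makes the step condition read as what it is: gradient accumulation, where several backward passes feed one optimizer step. A minimal sketch of the pattern, assuming a PyTorch model whose forward pass returns a scalar loss (names here are illustrative; the trainer's real loop is the one in the hunk above):

```python
from torch.nn.utils import clip_grad_norm_

def train_epoch(model, loader, optimizer, accumulation_steps=4, max_grad_norm=1.0):
    for current_iter, batch in enumerate(loader, start=1):
        loss = model(batch)  # assumption: forward returns the loss
        # Scale so the accumulated gradient matches one large-batch step.
        (loss / accumulation_steps).backward()
        if current_iter % accumulation_steps == 0:
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
```

Note that the trainer keeps its existing unscaled `loss.backward()`; dividing by `accumulation_steps` as above is a common refinement, not something this diff adds.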
diff --git a/train.py b/train.py
index b7274fc..64805d5 100644
--- a/train.py
+++ b/train.py
@@ -21,13 +21,13 @@ def train(
     train_type: str,
     param_path: str,
     data_root_path: str,
+    max_lr: float,
     n_epoch: int,
     batch_size: int,
-    n_iter_step: int,
-    warning_step: int,
-    max_lr: int,
-    n_iter_ckpt: int,
-    ckpt_dir: str,
+    accumulation_steps: int,
+    warmup_steps: int,
+    checkpoint_interval: int,
+    checkpoint_dir: str,
     dpo_beta: float,
     adamw_betas: tuple,
     adamw_weight_decay: float,
@@ -84,18 +84,18 @@ def train(
         strategy=strategy,
         dataset=dataset,
         optimizer=optim,
-        ckpt_dir=ckpt_dir,
+        checkpoint_dir=checkpoint_dir,
         n_epoch=n_epoch,
         batch_size=batch_size,
-        n_iter_ckpt=n_iter_ckpt,
-        n_iter_step=n_iter_step,
+        checkpoint_interval=checkpoint_interval,
+        accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
     )

     schedule_config = CosineScheduleConfig(
-        warning_step=warning_step,
-        total_iters=len(dataset) * n_epoch // batch_size,
+        warmup_steps=warmup_steps,
+        total_steps=len(dataset) * n_epoch // batch_size,
     )

     trainer = Trainer(
@@ -114,11 +114,11 @@ if __name__ == "__main__":
     parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
     parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
-    parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.")
-    parser.add_argument("--warning_step", type=int, default=1000, help="Number of iters between warnings.")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of batches to accumulate before each optimizer step.")
+    parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of warmup steps for the learning rate schedule.")
     parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
-    parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.")
-    parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
+    parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iterations between checkpoints.")
+    parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
     parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
     parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
@@ -136,16 +136,16 @@
         data_root_path=args.data_root_path,
         n_epoch=args.n_epoch,
         batch_size=args.batch_size,
-        n_iter_step=args.n_iter_step,
-        warning_step=args.warning_step,
+        accumulation_steps=args.accumulation_steps,
+        warmup_steps=args.warmup_steps,
         max_lr=args.max_lr,
         dpo_beta=args.dpo_beta,
         adamw_betas=args.adamw_betas,
         adamw_weight_decay=args.adamw_weight_decay,
         max_grad_norm=args.max_grad_norm,
         embdeding_lr_rate=args.embdeding_lr_rate,
-        n_iter_ckpt=args.n_iter_ckpt,
-        ckpt_dir=args.ckpt_dir,
+        checkpoint_interval=args.checkpoint_interval,
+        checkpoint_dir=args.checkpoint_dir,
         train_type=args.train_type,
         random_seed=args.random_seed,
         multi_turn=args.multi_turn
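A review note on the renamed schedule argument: `total_steps` is derived as `len(dataset) * n_epoch // batch_size`, which counts loop iterations, not optimizer steps. A worked example of the arithmetic (numbers are illustrative):

```python
# total_steps as computed in train.py
n_samples = 1_000_000  # stand-in for len(dataset)
n_epoch = 5
batch_size = 8

total_steps = n_samples * n_epoch // batch_size
assert total_steps == 625_000

# With gradient accumulation, optimizer steps are fewer than iterations:
accumulation_steps = 4
assert total_steps // accumulation_steps == 156_250
```

If the cosine schedule advances once per optimizer step rather than once per batch, `total_steps` would need the extra division by `accumulation_steps`; worth confirming against the scheduler's step cadence.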