refactor(trainer): unify parameter naming for readability

parent fa43ed2943
commit 1c9063fd3d

README.md (24 changes)
@@ -54,8 +54,8 @@ python train.py \
     --n_epoch=5 \
     --batch_size=8 \
     --max_lr=2e-4 \
-    --n_iter_ckpt=10000 \
-    --ckpt_dir=checkpoints
+    --checkpoint_interval=10000 \
+    --checkpoint_dir=checkpoints
 ```
 
 **Parameters Explanation:**
@@ -64,11 +64,11 @@ python train.py \
 - `--param_path`: Path to the model training parameters
 - `--n_epoch`: Total number of training epochs
 - `--batch_size`: Batch size
-- `--n_iter_step`: Number of batches per training step
-- `--warning_step`: Number of warmup steps
+- `--accumulation_steps`: Number of batches accumulated per optimizer step
+- `--warmup_steps`: Number of warmup steps
 - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
-- `--n_iter_ckpt`: Checkpoint saving interval
-- `--ckpt_dir`: Directory to save checkpoints
+- `--checkpoint_interval`: Checkpoint saving interval
+- `--checkpoint_dir`: Directory to save checkpoints
 - `--resume_dir`: Resume training from the specified path
 
 Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
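As background for `--max_lr` and `--warmup_steps`, here is a minimal, self-contained sketch of a "warmup + cosine decay" schedule like the one the README describes. It is illustrative only, not this repository's implementation; the function name, signature, and decay-to-zero behavior are assumptions.

```python
import math

def lr_at_step(step: int, max_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Hypothetical helper: linear warmup to max_lr, then cosine decay to 0."""
    if step < warmup_steps:
        # Linear warmup: ramp from 0 up to max_lr.
        return max_lr * step / max(1, warmup_steps)
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Example with max_lr=2e-4, as in the command above.
print(lr_at_step(500, 2e-4, 1000, 62500))   # mid-warmup: 1e-4
print(lr_at_step(1000, 2e-4, 1000, 62500))  # warmup done: 2e-4
```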
@@ -214,8 +214,8 @@ python train.py \
     --n_epoch=5 \
     --batch_size=8 \
     --max_lr=2e-4 \
-    --n_iter_ckpt=10000 \
-    --ckpt_dir=checkpoints
+    --checkpoint_interval=10000 \
+    --checkpoint_dir=checkpoints
 ```
 
 **Parameters Explanation:**

@@ -224,11 +224,11 @@ python train.py \
 - `--param_path`: Path to the model training parameters
 - `--n_epoch`: Total number of training epochs
 - `--batch_size`: Batch size
-- `--n_iter_step`: Number of batches per training step
-- `--warning_step`: Number of warmup steps
+- `--accumulation_steps`: Number of batches accumulated per optimizer step
+- `--warmup_steps`: Number of warmup steps
 - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
-- `--n_iter_ckpt`: Checkpoint saving interval
-- `--ckpt_dir`: Directory to save checkpoints
+- `--checkpoint_interval`: Checkpoint saving interval
+- `--checkpoint_dir`: Directory to save checkpoints
 - `--resume_dir`: Resume training from the specified path
 
 Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.

@@ -190,7 +190,7 @@ class TrainConfig:
         default=None,
         metadata={"help": "Optimizer for training."}
     )
-    ckpt_dir: str = field(
+    checkpoint_dir: str = field(
         default="./checkpoint",
         metadata={"help": "Checkpoint directory."}
     )

@@ -202,11 +202,11 @@ class TrainConfig:
         default=4,
         metadata={"help": "Batch size for training."}
     )
-    n_iter_ckpt: int = field(
+    checkpoint_interval: int = field(
         default=5000,
         metadata={"help": "Number of iterations between checkpoints."}
     )
-    n_iter_step: int = field(
+    accumulation_steps: int = field(
         default=1,
         metadata={"help": "Number of iterations between steps."}
     )
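The `field(default=..., metadata={"help": ...})` pattern above keeps each config value next to its CLI help text. As a hedged sketch of how such metadata can be consumed (not code from this repo; the class and function names are hypothetical), a generic loop can turn a dataclass into argparse options:

```python
import argparse
from dataclasses import dataclass, field, fields

@dataclass
class TrainConfigSketch:
    # Hypothetical stand-in mirroring the renamed fields above.
    checkpoint_dir: str = field(default="./checkpoint", metadata={"help": "Checkpoint directory."})
    checkpoint_interval: int = field(default=5000, metadata={"help": "Iterations between checkpoints."})
    accumulation_steps: int = field(default=1, metadata={"help": "Batches accumulated per optimizer step."})

def build_parser(cfg_cls) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    for f in fields(cfg_cls):
        # Reuse the dataclass default and help metadata for each flag.
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default,
                            help=f.metadata.get("help", ""))
    return parser

args = build_parser(TrainConfigSketch).parse_args([])
print(args.checkpoint_interval)  # 5000
```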
@@ -256,7 +256,7 @@ class ScheduleConfig(ABC):
 
 @dataclass
 class CosineScheduleConfig(ScheduleConfig):
-    total_steps: int = field(  # more accurate naming
+    total_steps: int = field(
         default=None,
         metadata={"help": "Total training steps for cosine schedule."}
     )
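One consequence of this rename worth spelling out: `total_steps` is filled in from the dataset size in train.py below, and it counts per-batch iterations. A quick worked example, under an assumed dataset size (the sample count is hypothetical):

```python
# Assumed values for illustration only.
n_samples = 100_000
n_epoch = 5
batch_size = 8

total_steps = n_samples * n_epoch // batch_size  # matches the expression in train.py
print(total_steps)  # 62500
```

Note that this counts batches, not optimizer steps: with `accumulation_steps` greater than 1, the optimizer steps less often than once per counted iteration.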

@@ -32,7 +32,7 @@ class Trainer:
         train_config: TrainConfig
     ):
         current_iter = len(loss_list)
-        save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}")
+        save_path = os.path.join(train_config.checkpoint_dir, f"iter_{current_iter}")
         self.checkpoint.loss_list = loss_list
         self.checkpoint.optim_state = train_config.optimizer.state_dict()
         self.checkpoint.save(save_path)

@@ -93,7 +93,7 @@ class Trainer:
         # backward
         loss.backward()
         # step
-        if current_iter % train_config.n_iter_step == 0:
+        if current_iter % train_config.accumulation_steps == 0:
             clip_grad_norm_(
                 self.checkpoint.model.parameters(),
                 train_config.max_grad_norm
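For context on the `accumulation_steps` check above, here is a minimal PyTorch-style sketch of gradient accumulation with clipping. The loop and names are illustrative, not the Trainer's actual code; in particular, `model(batch)` returning a loss and the 1/accumulation_steps loss scaling are assumptions about how the loss is produced and normalized.

```python
import torch
from torch.nn.utils import clip_grad_norm_

def train_steps(model, optimizer, batches, accumulation_steps: int, max_grad_norm: float):
    optimizer.zero_grad()
    for current_iter, batch in enumerate(batches, start=1):
        loss = model(batch)
        # Scale so the accumulated gradient matches one large-batch step.
        (loss / accumulation_steps).backward()
        if current_iter % accumulation_steps == 0:
            # Clip, then apply one optimizer step per accumulation window.
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
```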
@@ -108,7 +108,7 @@ class Trainer:
                 "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
             })
         # save checkpoint
-        if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
+        if current_iter - last_ckpt_iter >= train_config.checkpoint_interval:
             self.save_checkpoint(loss_list, train_config)
             last_ckpt_iter = current_iter
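`save_checkpoint` above stores the loss history and optimizer state alongside the model, under an `iter_{n}` directory. A hedged sketch of the inverse resume path (the `Checkpoint.load` API and `resume_from` helper are hypothetical; the repo's actual loading code is not part of this diff):

```python
import os

def resume_from(checkpoint_dir: str, iteration: int, checkpoint, optimizer):
    # Mirror of save_checkpoint: paths follow the iter_{n} convention above.
    load_path = os.path.join(checkpoint_dir, f"iter_{iteration}")
    checkpoint.load(load_path)                       # hypothetical API
    optimizer.load_state_dict(checkpoint.optim_state)
    return checkpoint.loss_list                      # resume loss history
```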

train.py (36 changes)
@@ -21,13 +21,13 @@ def train(
     train_type: str,
     param_path: str,
     data_root_path: str,
+    max_lr: float,
     n_epoch: int,
     batch_size: int,
-    n_iter_step: int,
-    warning_step: int,
-    max_lr: int,
-    n_iter_ckpt: int,
-    ckpt_dir: str,
+    accumulation_steps: int,
+    warmup_steps: int,
+    checkpoint_interval: int,
+    checkpoint_dir: str,
     dpo_beta: float,
     adamw_betas: tuple,
     adamw_weight_decay: float,
@@ -84,18 +84,18 @@ def train(
         strategy=strategy,
         dataset=dataset,
         optimizer=optim,
-        ckpt_dir=ckpt_dir,
+        checkpoint_dir=checkpoint_dir,
         n_epoch=n_epoch,
         batch_size=batch_size,
-        n_iter_ckpt=n_iter_ckpt,
-        n_iter_step=n_iter_step,
+        checkpoint_interval=checkpoint_interval,
+        accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
     )
 
     schedule_config = CosineScheduleConfig(
-        warning_step=warning_step,
-        total_iters=len(dataset) * n_epoch // batch_size,
+        warmup_steps=warmup_steps,
+        total_steps=len(dataset) * n_epoch // batch_size,
     )
 
     trainer = Trainer(
@@ -114,11 +114,11 @@ if __name__ == "__main__":
     parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
     parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
-    parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.")
-    parser.add_argument("--warning_step", type=int, default=1000, help="Number of iters between warnings.")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
+    parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of warmup steps.")
     parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
-    parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.")
-    parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
+    parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
+    parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
     parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
     parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
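One pre-existing wrinkle visible in the context lines (untouched by this rename): `type=tuple` makes argparse call `tuple()` on the raw string, so passing `--adamw_betas 0.9,0.95` on the command line would yield a tuple of individual characters rather than two floats. The default is unaffected because non-string defaults bypass `type`. A sketch of a safer converter, as an assumption about how one might fix it rather than code from this commit:

```python
def float_pair(s: str) -> tuple:
    # "0.9,0.95" -> (0.9, 0.95); raises ValueError on bad input.
    parts = tuple(float(x) for x in s.split(","))
    if len(parts) != 2:
        raise ValueError(f"expected two comma-separated floats, got {s!r}")
    return parts

# Hypothetical usage replacing type=tuple:
# parser.add_argument("--adamw_betas", type=float_pair, default=(0.9, 0.95), ...)
```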
@@ -136,16 +136,16 @@ if __name__ == "__main__":
         data_root_path=args.data_root_path,
         n_epoch=args.n_epoch,
         batch_size=args.batch_size,
-        n_iter_step=args.n_iter_step,
-        warning_step=args.warning_step,
+        accumulation_steps=args.accumulation_steps,
+        warmup_steps=args.warmup_steps,
         max_lr=args.max_lr,
         dpo_beta=args.dpo_beta,
         adamw_betas=args.adamw_betas,
         adamw_weight_decay=args.adamw_weight_decay,
         max_grad_norm=args.max_grad_norm,
         embdeding_lr_rate=args.embdeding_lr_rate,
-        n_iter_ckpt=args.n_iter_ckpt,
-        ckpt_dir=args.ckpt_dir,
+        checkpoint_interval=args.checkpoint_interval,
+        checkpoint_dir=args.checkpoint_dir,
         train_type=args.train_type,
         random_seed=args.random_seed,
         multi_turn=args.multi_turn