diff --git a/tools/train.py b/tools/train.py index bf3287f..1186206 100644 --- a/tools/train.py +++ b/tools/train.py @@ -35,7 +35,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.") parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") - parser.add_argument("--label_smoothing", type=int, default=0.1, help="cross_entropy function label smoothing parameter") + parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter") parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.") parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")