From e35cb0d84afa899f3f8ad00f8c0d928baed5dba3 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 13 Mar 2026 22:37:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=20label=20smoothing?= =?UTF-8?q?=20=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/train.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/tools/train.py b/tools/train.py index 387649a..ec22690 100644 --- a/tools/train.py +++ b/tools/train.py @@ -5,7 +5,6 @@ import torch.nn as nn import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP -from typing import List, Optional from functools import partial from khaosz.data import DatasetLoader from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig @@ -14,14 +13,6 @@ from khaosz.parallel import get_rank def parse_args() -> argparse.Namespace: - def parse_device_ids(s: Optional[str]) -> Optional[List[int]]: - if s is None or s.strip() == "": - return None - try: - return [int(x.strip()) for x in s.split(",") if x.strip()] - except ValueError as e: - raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. 
Expected comma-separated integers like '0,1,2'.") - parser = argparse.ArgumentParser(description="Train the Transformer model.") @@ -44,6 +35,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.") parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") + parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter") parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.") parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") @@ -51,7 +43,6 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") - parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.") parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.") args = parser.parse_args() @@ -101,13 +92,13 @@ def train( adamw_beta2: float, adamw_weight_decay: float, max_grad_norm: float, + label_smoothing: float, random_seed: int, num_workers: int, pin_memory: bool, window_size: int, stride: int, nprocs: int, - device_ids: List[int], device_type: str, ): assert train_type in ["seq", "sft", "dpo"] @@ -126,6 +117,7 @@ def train( "bos_token_id": parameter.tokenizer.bos_id, "eos_token_id": parameter.tokenizer.eos_id, "pad_token_id": parameter.tokenizer.pad_id, + "label_smoothing": label_smoothing } dataset = DatasetLoader.load( @@ -165,7 +157,6 @@ def train( nprocs=nprocs, parallel_wrapper=ddp_wrap, state_dict_fn=prepare_checkpoint, - device_ids=device_ids, device_type=device_type, 
extra_kwargs=kwargs, )