feat: 增加 label smoothing 设置

This commit is contained in:
ViperEkura 2026-03-13 22:37:27 +08:00
parent 6d6ef6dbb6
commit e35cb0d84a
1 changed files with 3 additions and 12 deletions

View File

@ -5,7 +5,6 @@ import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import List, Optional
from functools import partial
from khaosz.data import DatasetLoader
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
@ -14,14 +13,6 @@ from khaosz.parallel import get_rank
def parse_args() -> argparse.Namespace:
def parse_device_ids(s: Optional[str]) -> Optional[List[int]]:
if s is None or s.strip() == "":
return None
try:
return [int(x.strip()) for x in s.split(",") if x.strip()]
except ValueError as e:
raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. Expected comma-separated integers like '0,1,2'.")
parser = argparse.ArgumentParser(description="Train the Transformer model.")
@ -44,6 +35,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--label_smoothing", type=int, default=0.1, help="cross_entropy function label smoothing parameter")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
@ -51,7 +43,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.")
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
args = parser.parse_args()
@ -101,13 +92,13 @@ def train(
adamw_beta2: float,
adamw_weight_decay: float,
max_grad_norm: float,
label_smoothing: float,
random_seed: int,
num_workers: int,
pin_memory: bool,
window_size: int,
stride: int,
nprocs: int,
device_ids: List[int],
device_type: str,
):
assert train_type in ["seq", "sft", "dpo"]
@ -126,6 +117,7 @@ def train(
"bos_token_id": parameter.tokenizer.bos_id,
"eos_token_id": parameter.tokenizer.eos_id,
"pad_token_id": parameter.tokenizer.pad_id,
"label_smoothing": label_smoothing
}
dataset = DatasetLoader.load(
@ -165,7 +157,6 @@ def train(
nprocs=nprocs,
parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint,
device_ids=device_ids,
device_type=device_type,
extra_kwargs=kwargs,
)