feat: 增加 label smoothing 设置
This commit is contained in:
parent
6d6ef6dbb6
commit
e35cb0d84a
|
|
@ -5,7 +5,6 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from typing import List, Optional
|
||||
from functools import partial
|
||||
from khaosz.data import DatasetLoader
|
||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||
|
|
@ -14,14 +13,6 @@ from khaosz.parallel import get_rank
|
|||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
def parse_device_ids(s: Optional[str]) -> Optional[List[int]]:
|
||||
if s is None or s.strip() == "":
|
||||
return None
|
||||
try:
|
||||
return [int(x.strip()) for x in s.split(",") if x.strip()]
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. Expected comma-separated integers like '0,1,2'.") from e
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||
|
||||
|
|
@ -44,6 +35,7 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
||||
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.1, help="cross_entropy function label smoothing parameter")
|
||||
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
||||
|
|
@ -51,7 +43,6 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||
|
||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||
parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.")
|
||||
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
@ -101,13 +92,13 @@ def train(
|
|||
adamw_beta2: float,
|
||||
adamw_weight_decay: float,
|
||||
max_grad_norm: float,
|
||||
label_smoothing: float,
|
||||
random_seed: int,
|
||||
num_workers: int,
|
||||
pin_memory: bool,
|
||||
window_size: int,
|
||||
stride: int,
|
||||
nprocs: int,
|
||||
device_ids: List[int],
|
||||
device_type: str,
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo"]
|
||||
|
|
@ -126,6 +117,7 @@ def train(
|
|||
"bos_token_id": parameter.tokenizer.bos_id,
|
||||
"eos_token_id": parameter.tokenizer.eos_id,
|
||||
"pad_token_id": parameter.tokenizer.pad_id,
|
||||
"label_smoothing": label_smoothing
|
||||
}
|
||||
|
||||
dataset = DatasetLoader.load(
|
||||
|
|
@ -165,7 +157,6 @@ def train(
|
|||
nprocs=nprocs,
|
||||
parallel_wrapper=ddp_wrap,
|
||||
state_dict_fn=prepare_checkpoint,
|
||||
device_ids=device_ids,
|
||||
device_type=device_type,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue