feat: 增加 label smoothing 设置
This commit is contained in:
parent
6d6ef6dbb6
commit
e35cb0d84a
|
|
@ -5,7 +5,6 @@ import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from khaosz.data import DatasetLoader
|
from khaosz.data import DatasetLoader
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
|
|
@ -14,14 +13,6 @@ from khaosz.parallel import get_rank
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
def parse_device_ids(s: Optional[str]) -> Optional[List[int]]:
|
|
||||||
if s is None or s.strip() == "":
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return [int(x.strip()) for x in s.split(",") if x.strip()]
|
|
||||||
except ValueError as e:
|
|
||||||
raise argparse.ArgumentTypeError(f"Invalid device_ids format: {s}. Expected comma-separated integers like '0,1,2'.")
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||||
|
|
||||||
|
|
@ -44,6 +35,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
parser.add_argument("--window_size", type=int, default=None, help="the max length of the input sequence.")
|
||||||
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
parser.add_argument("--stride", type=int, default=None, help="the step size of the input sequence.")
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
|
parser.add_argument("--label_smoothing", type=int, default=0.1, help="cross_entropy function label smoothing parameter")
|
||||||
|
|
||||||
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
||||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
||||||
|
|
@ -51,7 +43,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||||
|
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
parser.add_argument("--device_ids", type=parse_device_ids, default=None, help="Device IDs to use.")
|
|
||||||
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
parser.add_argument("--device_type", type=str, default="cuda", help="Device type to use.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -101,13 +92,13 @@ def train(
|
||||||
adamw_beta2: float,
|
adamw_beta2: float,
|
||||||
adamw_weight_decay: float,
|
adamw_weight_decay: float,
|
||||||
max_grad_norm: float,
|
max_grad_norm: float,
|
||||||
|
label_smoothing: float,
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
device_ids: List[int],
|
|
||||||
device_type: str,
|
device_type: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
|
|
@ -126,6 +117,7 @@ def train(
|
||||||
"bos_token_id": parameter.tokenizer.bos_id,
|
"bos_token_id": parameter.tokenizer.bos_id,
|
||||||
"eos_token_id": parameter.tokenizer.eos_id,
|
"eos_token_id": parameter.tokenizer.eos_id,
|
||||||
"pad_token_id": parameter.tokenizer.pad_id,
|
"pad_token_id": parameter.tokenizer.pad_id,
|
||||||
|
"label_smoothing": label_smoothing
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
|
|
@ -165,7 +157,6 @@ def train(
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
parallel_wrapper=ddp_wrap,
|
parallel_wrapper=ddp_wrap,
|
||||||
state_dict_fn=prepare_checkpoint,
|
state_dict_fn=prepare_checkpoint,
|
||||||
device_ids=device_ids,
|
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
extra_kwargs=kwargs,
|
extra_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue