feat(trainer): 重构训练配置与策略工厂引入
This commit is contained in:
parent
2dc7b5bda8
commit
fa43ed2943
|
|
@ -18,9 +18,13 @@ from khaosz.core.generator import (
|
||||||
RetrievalGenerator,
|
RetrievalGenerator,
|
||||||
EmbeddingEncoder
|
EmbeddingEncoder
|
||||||
)
|
)
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer import (
|
||||||
from khaosz.trainer.dataset import SeqDataset, SftDataset, DpoDataset, BaseDataset
|
Trainer,
|
||||||
|
DatasetLoader,
|
||||||
|
TrainConfig,
|
||||||
|
StrategyFactory,
|
||||||
|
SchedulerFactory
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# model
|
# model
|
||||||
|
|
@ -40,10 +44,10 @@ __all__ = [
|
||||||
|
|
||||||
# trainer
|
# trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"SeqDataset",
|
"DatasetLoader",
|
||||||
"SftDataset",
|
"TrainConfig",
|
||||||
"DpoDataset",
|
"StrategyFactory",
|
||||||
"BaseDataset",
|
"SchedulerFactory",
|
||||||
|
|
||||||
# utils
|
# utils
|
||||||
"Retriever",
|
"Retriever",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
from khaosz.trainer.dataset import DatasetLoader
|
from khaosz.trainer.dataset import DatasetLoader
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
from khaosz.trainer.strategy import TrainConfig, CosineScheduleConfig, SgdrScheduleConfig
|
from khaosz.trainer.strategy import (
|
||||||
|
TrainConfig,
|
||||||
|
CosineScheduleConfig,
|
||||||
|
SgdrScheduleConfig,
|
||||||
|
StrategyFactory,
|
||||||
|
SchedulerFactory
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
|
|
@ -8,4 +14,6 @@ __all__ = [
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"CosineScheduleConfig",
|
"CosineScheduleConfig",
|
||||||
"SgdrScheduleConfig",
|
"SgdrScheduleConfig",
|
||||||
|
"StrategyFactory",
|
||||||
|
"SchedulerFactory"
|
||||||
]
|
]
|
||||||
|
|
@ -257,7 +257,7 @@ class DatasetLoader:
|
||||||
bos_token_id=kwargs.get("bos_token_id"),
|
bos_token_id=kwargs.get("bos_token_id"),
|
||||||
eos_token_id=kwargs.get("eos_token_id"),
|
eos_token_id=kwargs.get("eos_token_id"),
|
||||||
user_token_id=kwargs.get("user_token_id"),
|
user_token_id=kwargs.get("user_token_id"),
|
||||||
multi_turn=kwargs.get("multi_turn", False)
|
multi_turn=kwargs.get("multi_turn")
|
||||||
),
|
),
|
||||||
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
|
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -177,9 +177,10 @@ class StrategyFactory:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
train_type: str = field(
|
|
||||||
default_factory=["seq", "sft", "dpo"],
|
strategy: BaseStrategy = field(
|
||||||
metadata={"help": "Type of training."}
|
default=None,
|
||||||
|
metadata={"help": "Training strategy."}
|
||||||
)
|
)
|
||||||
dataset: Dataset = field(
|
dataset: Dataset = field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|
@ -217,10 +218,6 @@ class TrainConfig:
|
||||||
default=3407,
|
default=3407,
|
||||||
metadata={"help": "Random seed."}
|
metadata={"help": "Random seed."}
|
||||||
)
|
)
|
||||||
dpo_beta: float = field(
|
|
||||||
default=0.1,
|
|
||||||
metadata={"help": "DPO beta."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_kwargs(self)-> Dict[str, Any]:
|
def get_kwargs(self)-> Dict[str, Any]:
|
||||||
config_dict = asdict(self)
|
config_dict = asdict(self)
|
||||||
|
|
@ -228,117 +225,191 @@ class TrainConfig:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScheduleConfig:
|
class ScheduleConfig(ABC):
|
||||||
schedule_type: str = field(
|
schedule_type: str = field(
|
||||||
default_factory=["cosine", "sgdr"],
|
default="cosine",
|
||||||
metadata={"help": "Type of learning rate schedule."}
|
metadata={
|
||||||
|
"help": "Type of learning rate schedule.",
|
||||||
|
"choices": ["cosine", "sgdr"]
|
||||||
|
}
|
||||||
)
|
)
|
||||||
warning_step: int = field(
|
warmup_steps: int = field(
|
||||||
default=1000,
|
default=1000,
|
||||||
metadata= {"help": "Warning up step."}
|
metadata={"help": "Number of warmup steps."}
|
||||||
)
|
)
|
||||||
|
min_rate: float = field(
|
||||||
|
default=0.05,
|
||||||
|
metadata={"help": "Minimum learning rate multiplier."}
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""Validate configuration parameters."""
|
||||||
|
if self.warmup_steps < 0:
|
||||||
|
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
|
||||||
|
if not 0 <= self.min_rate <= 1:
|
||||||
|
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CosineScheduleConfig(ScheduleConfig):
|
class CosineScheduleConfig(ScheduleConfig):
|
||||||
total_iters: int = field(
|
total_steps: int = field( # 更准确的命名
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Total iterations for cosine schedule."}
|
metadata={"help": "Total training steps for cosine schedule."}
|
||||||
)
|
|
||||||
min_rate: float = field(
|
|
||||||
default=0.05,
|
|
||||||
metadata={"help": "Minimum rate for cosine schedule."}
|
|
||||||
)
|
)
|
||||||
schedule_type: Literal["cosine"] = "cosine"
|
schedule_type: Literal["cosine"] = "cosine"
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
|
if self.total_steps is None:
|
||||||
|
raise ValueError("total_steps must be specified for cosine schedule")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"schedule_type": self.schedule_type,
|
"schedule_type": self.schedule_type,
|
||||||
"warning_step": self.warning_step,
|
"warmup_steps": self.warmup_steps,
|
||||||
"lr_decay_iters": self.total_iters - self.warning_step,
|
"lr_decay_steps": self.total_steps - self.warmup_steps,
|
||||||
"min_rate": self.min_rate
|
"min_rate": self.min_rate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
super().validate()
|
||||||
|
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
|
||||||
|
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SgdrScheduleConfig(ScheduleConfig):
|
class SgdrScheduleConfig(ScheduleConfig):
|
||||||
cycle_length: int = field(
|
cycle_length: int = field(
|
||||||
default=1000,
|
default=1000,
|
||||||
metadata={"help": "Cycle length for sgdr schedule."}
|
metadata={"help": "Length of the first cycle in steps."}
|
||||||
)
|
)
|
||||||
min_rate: float = field(
|
t_mult: int = field(
|
||||||
default=0.05,
|
|
||||||
metadata={"help": "Minimum rate for sgdr schedule."}
|
|
||||||
)
|
|
||||||
T_mult: int = field(
|
|
||||||
default=2,
|
default=2,
|
||||||
metadata={"help": "T_mult for sgdr schedule."}
|
metadata={"help": "Multiplier for cycle length growth."}
|
||||||
)
|
)
|
||||||
schedule_type: Literal["sgdr"] = "sgdr"
|
schedule_type: Literal["sgdr"] = "sgdr"
|
||||||
|
|
||||||
def get_kwargs(self) -> Dict[str, Any]:
|
def get_kwargs(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"schedule_type": self.schedule_type,
|
"schedule_type": self.schedule_type,
|
||||||
"warning_step": self.warning_step,
|
"warmup_steps": self.warmup_steps,
|
||||||
"cycle_length": self.cycle_length,
|
"cycle_length": self.cycle_length,
|
||||||
"min_rate": self.min_rate,
|
"min_rate": self.min_rate,
|
||||||
"T_mult": self.T_mult
|
"t_mult": self.t_mult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
super().validate()
|
||||||
|
if self.cycle_length <= 0:
|
||||||
|
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
|
||||||
|
if self.t_mult < 1:
|
||||||
|
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
||||||
|
|
||||||
|
|
||||||
class SchedulerFactory:
|
class SchedulerFactory:
|
||||||
|
"""Factory for creating learning rate schedule functions."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_sgdr_schedule(
|
def get_sgdr_schedule(
|
||||||
warning_step: int,
|
warmup_steps: int,
|
||||||
cycle_length: int,
|
cycle_length: int,
|
||||||
min_rate: float = 0.1,
|
min_rate: float = 0.05,
|
||||||
T_mult: int = 2
|
t_mult: int = 2
|
||||||
) -> Callable[[int], float]:
|
) -> Callable[[int], float]:
|
||||||
|
"""
|
||||||
|
Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.
|
||||||
|
|
||||||
def sgdr_schedule(now_iter: int) -> float:
|
Args:
|
||||||
if now_iter < warning_step:
|
warmup_steps: Number of warmup steps
|
||||||
return max(min_rate, now_iter / warning_step)
|
cycle_length: Length of the first cycle
|
||||||
|
min_rate: Minimum learning rate multiplier
|
||||||
|
t_mult: Cycle length multiplier
|
||||||
|
|
||||||
adjusted_iter = now_iter - warning_step
|
Returns:
|
||||||
total_cycles, current_cycle = 0, 0
|
Schedule function that takes current step and returns LR multiplier
|
||||||
while adjusted_iter >= cycle_length * (T_mult ** total_cycles):
|
"""
|
||||||
current_cycle += 1
|
|
||||||
total_cycles += 1
|
|
||||||
|
|
||||||
cycle_start = sum(cycle_length * (T_mult ** i) for i in range(current_cycle))
|
def sgdr_schedule(current_step: int) -> float:
|
||||||
cycle_pos = adjusted_iter - cycle_start
|
# Warmup phase
|
||||||
|
if current_step < warmup_steps:
|
||||||
|
return max(min_rate, current_step / warmup_steps)
|
||||||
|
|
||||||
cycle_length_current = cycle_length * (T_mult ** current_cycle)
|
# SGDR phase
|
||||||
return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / cycle_length_current)))
|
steps_since_warmup = current_step - warmup_steps
|
||||||
|
|
||||||
|
# Find current cycle and position within cycle
|
||||||
|
cycle_start = 0
|
||||||
|
current_cycle_length = cycle_length
|
||||||
|
cycle_index = 0
|
||||||
|
|
||||||
|
while steps_since_warmup >= cycle_start + current_cycle_length:
|
||||||
|
cycle_start += current_cycle_length
|
||||||
|
current_cycle_length *= t_mult
|
||||||
|
cycle_index += 1
|
||||||
|
|
||||||
|
position_in_cycle = steps_since_warmup - cycle_start
|
||||||
|
progress = position_in_cycle / current_cycle_length
|
||||||
|
|
||||||
|
# Cosine annealing within cycle
|
||||||
|
return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))
|
||||||
|
|
||||||
return sgdr_schedule
|
return sgdr_schedule
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cosine_warmup_schedule(
|
def get_cosine_schedule(
|
||||||
warning_step: int,
|
warmup_steps: int,
|
||||||
lr_decay_iters: int,
|
lr_decay_steps: int,
|
||||||
min_rate: float = 0.1
|
min_rate: float = 0.05
|
||||||
) -> Callable[[int], float]:
|
) -> Callable[[int], float]:
|
||||||
|
"""
|
||||||
|
Create cosine decay schedule with warmup.
|
||||||
|
|
||||||
def cosine_warmup_schedule(now_iter: int) -> float:
|
Args:
|
||||||
if now_iter <= warning_step:
|
warmup_steps: Number of warmup steps
|
||||||
return max(min_rate, now_iter / warning_step)
|
lr_decay_steps: Number of steps for cosine decay after warmup
|
||||||
|
min_rate: Minimum learning rate multiplier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Schedule function that takes current step and returns LR multiplier
|
||||||
|
"""
|
||||||
|
|
||||||
|
def cosine_schedule(current_step: int) -> float:
|
||||||
|
if current_step < warmup_steps:
|
||||||
|
# Linear warmup
|
||||||
|
return max(min_rate, current_step / warmup_steps)
|
||||||
else:
|
else:
|
||||||
rate = (now_iter - warning_step) / (lr_decay_iters - warning_step)
|
# Cosine decay
|
||||||
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate)))
|
decay_progress = (current_step - warmup_steps) / lr_decay_steps
|
||||||
|
decay_progress = min(decay_progress, 1.0) # Clamp at 1.0
|
||||||
|
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))
|
||||||
|
|
||||||
return cosine_warmup_schedule
|
return cosine_schedule
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_schedule_fn(**kwargs):
|
def create_schedule(config: ScheduleConfig) -> Callable[[int], float]:
|
||||||
strategy = kwargs.pop("schedule_type")
|
"""
|
||||||
if strategy == "cosine":
|
Create schedule from configuration.
|
||||||
return SchedulerFactory.get_cosine_warmup_schedule(**kwargs)
|
|
||||||
elif strategy == "sgdr":
|
Args:
|
||||||
|
config: Schedule configuration instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Schedule function
|
||||||
|
"""
|
||||||
|
config.validate()
|
||||||
|
kwargs = config.get_kwargs()
|
||||||
|
return SchedulerFactory.load_schedule_fn(**kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_schedule_fn(**kwargs) -> Callable[[int], float]:
|
||||||
|
schedule_type = kwargs.pop("schedule_type")
|
||||||
|
|
||||||
|
if schedule_type == "cosine":
|
||||||
|
return SchedulerFactory.get_cosine_schedule(**kwargs)
|
||||||
|
elif schedule_type == "sgdr":
|
||||||
return SchedulerFactory.get_sgdr_schedule(**kwargs)
|
return SchedulerFactory.get_sgdr_schedule(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid schedule type: {strategy}")
|
raise ValueError(f"Unsupported schedule type: {schedule_type}")
|
||||||
|
|
||||||
|
|
@ -8,19 +8,23 @@ from torch.utils.data import DataLoader, RandomSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from khaosz.core import ModelParameter, Checkpoint
|
from khaosz.core import ModelParameter, Checkpoint
|
||||||
from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig
|
from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parameter: ModelParameter
|
parameter: ModelParameter,
|
||||||
|
train_config: TrainConfig,
|
||||||
|
schedule_config: ScheduleConfig
|
||||||
):
|
):
|
||||||
self.checkpoint = Checkpoint(
|
self.checkpoint = Checkpoint(
|
||||||
model=parameter.model,
|
model=parameter.model,
|
||||||
tokenizer=parameter.tokenizer,
|
tokenizer=parameter.tokenizer,
|
||||||
config=parameter.config,
|
config=parameter.config,
|
||||||
)
|
)
|
||||||
|
self.train_config = train_config
|
||||||
|
self.schedule_config = schedule_config
|
||||||
|
|
||||||
def save_checkpoint(
|
def save_checkpoint(
|
||||||
self,
|
self,
|
||||||
|
|
@ -35,12 +39,11 @@ class Trainer:
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
train_config: TrainConfig,
|
|
||||||
schedule_config: ScheduleConfig,
|
|
||||||
train_checkpoint: Optional[Checkpoint] = None
|
train_checkpoint: Optional[Checkpoint] = None
|
||||||
) -> Checkpoint:
|
) -> Checkpoint:
|
||||||
|
train_config = self.train_config
|
||||||
|
schedule_config = self.schedule_config
|
||||||
assert schedule_config.schedule_type in ["cosine", "sgdr"]
|
assert schedule_config.schedule_type in ["cosine", "sgdr"]
|
||||||
assert train_config.train_type in ["seq", "sft", "dpo"]
|
|
||||||
|
|
||||||
if train_checkpoint:
|
if train_checkpoint:
|
||||||
self.checkpoint = train_checkpoint
|
self.checkpoint = train_checkpoint
|
||||||
|
|
@ -60,19 +63,6 @@ class Trainer:
|
||||||
**schedule_config.get_kwargs()
|
**schedule_config.get_kwargs()
|
||||||
)
|
)
|
||||||
|
|
||||||
strategy_kwargs = {
|
|
||||||
"bos_token_id": self.checkpoint.tokenizer.bos_id,
|
|
||||||
"eos_token_id": self.checkpoint.tokenizer.eos_id,
|
|
||||||
"pad_token_id": self.checkpoint.tokenizer.pad_id,
|
|
||||||
"dpo_beta": train_config.dpo_beta
|
|
||||||
}
|
|
||||||
|
|
||||||
strategy = StrategyFactory.load(
|
|
||||||
self.checkpoint.model,
|
|
||||||
train_config.train_type,
|
|
||||||
**strategy_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler = LambdaLR(
|
scheduler = LambdaLR(
|
||||||
train_config.optimizer,
|
train_config.optimizer,
|
||||||
lambda_scheduler_fn,
|
lambda_scheduler_fn,
|
||||||
|
|
@ -98,7 +88,7 @@ class Trainer:
|
||||||
)
|
)
|
||||||
for batch in progress_bar:
|
for batch in progress_bar:
|
||||||
#forward
|
#forward
|
||||||
loss = strategy(batch)
|
loss = train_config.strategy(batch)
|
||||||
loss_list.append(loss.item())
|
loss_list.append(loss.item())
|
||||||
#backward
|
#backward
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
|
||||||
23
train.py
23
train.py
|
|
@ -5,6 +5,7 @@ import torch
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from khaosz.core import ParameterLoader
|
from khaosz.core import ParameterLoader
|
||||||
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
|
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
|
||||||
|
from khaosz.trainer import StrategyFactory
|
||||||
|
|
||||||
|
|
||||||
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
@ -46,19 +47,26 @@ def train(
|
||||||
|
|
||||||
cache_files = get_files(data_root_path)
|
cache_files = get_files(data_root_path)
|
||||||
|
|
||||||
dataset_kwargs = {
|
strategy_kwargs = {
|
||||||
"multi_turn": multi_turn,
|
"multi_turn": multi_turn,
|
||||||
"bos_token_id": parameter.tokenizer.bos_id,
|
"bos_token_id": parameter.tokenizer.bos_id,
|
||||||
"eos_token_id": parameter.tokenizer.eos_id,
|
"eos_token_id": parameter.tokenizer.eos_id,
|
||||||
"user_token_id":parameter.tokenizer.encode("<|user|>")[0]
|
"user_token_id":parameter.tokenizer.encode("<|user|>")[0],
|
||||||
|
"dpo_beta": dpo_beta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
strategy = StrategyFactory.load(
|
||||||
|
model,
|
||||||
|
train_type
|
||||||
|
**strategy_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=cache_files,
|
load_path=cache_files,
|
||||||
max_len=parameter.config.m_len,
|
max_len=parameter.config.m_len,
|
||||||
device=device,
|
device=device,
|
||||||
dataset_kwargs=dataset_kwargs
|
dataset_kwargs=strategy_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
param_groups = [
|
param_groups = [
|
||||||
|
|
@ -73,7 +81,7 @@ def train(
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
train_type=train_type,
|
strategy=strategy,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer=optim,
|
optimizer=optim,
|
||||||
ckpt_dir=ckpt_dir,
|
ckpt_dir=ckpt_dir,
|
||||||
|
|
@ -83,7 +91,6 @@ def train(
|
||||||
n_iter_step=n_iter_step,
|
n_iter_step=n_iter_step,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
dpo_beta=dpo_beta
|
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
|
|
@ -91,11 +98,13 @@ def train(
|
||||||
total_iters=len(dataset) * n_epoch // batch_size,
|
total_iters=len(dataset) * n_epoch // batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(parameter)
|
trainer = Trainer(
|
||||||
trainer.train(
|
parameter=parameter,
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
schedule_config=schedule_config,
|
schedule_config=schedule_config,
|
||||||
)
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue