diff --git a/khaosz/__init__.py b/khaosz/__init__.py index b4bbcf5..23f4ff9 100644 --- a/khaosz/__init__.py +++ b/khaosz/__init__.py @@ -18,9 +18,13 @@ from khaosz.core.generator import ( RetrievalGenerator, EmbeddingEncoder ) -from khaosz.trainer.trainer import Trainer -from khaosz.trainer.dataset import SeqDataset, SftDataset, DpoDataset, BaseDataset - +from khaosz.trainer import ( + Trainer, + DatasetLoader, + TrainConfig, + StrategyFactory, + SchedulerFactory +) __all__ = [ # model @@ -40,10 +44,10 @@ __all__ = [ # trainer "Trainer", - "SeqDataset", - "SftDataset", - "DpoDataset", - "BaseDataset", + "DatasetLoader", + "TrainConfig", + "StrategyFactory", + "SchedulerFactory", # utils "Retriever", diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index 7630e6a..6d40793 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -1,6 +1,12 @@ from khaosz.trainer.dataset import DatasetLoader from khaosz.trainer.trainer import Trainer -from khaosz.trainer.strategy import TrainConfig, CosineScheduleConfig, SgdrScheduleConfig +from khaosz.trainer.strategy import ( + TrainConfig, + CosineScheduleConfig, + SgdrScheduleConfig, + StrategyFactory, + SchedulerFactory +) __all__ = [ "DatasetLoader", @@ -8,4 +14,6 @@ __all__ = [ "TrainConfig", "CosineScheduleConfig", "SgdrScheduleConfig", + "StrategyFactory", + "SchedulerFactory" ] \ No newline at end of file diff --git a/khaosz/trainer/dataset.py b/khaosz/trainer/dataset.py index 908447e..3cf466e 100644 --- a/khaosz/trainer/dataset.py +++ b/khaosz/trainer/dataset.py @@ -257,7 +257,7 @@ class DatasetLoader: bos_token_id=kwargs.get("bos_token_id"), eos_token_id=kwargs.get("eos_token_id"), user_token_id=kwargs.get("user_token_id"), - multi_turn=kwargs.get("multi_turn", False) + multi_turn=kwargs.get("multi_turn") ), "dpo": lambda m_len, device: DpoDataset(m_len, device=device), } diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index e94b122..696bf83 100644 --- 
a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -177,9 +177,10 @@ class StrategyFactory: @dataclass class TrainConfig: - train_type: str = field( - default_factory=["seq", "sft", "dpo"], - metadata={"help": "Type of training."} + + strategy: BaseStrategy = field( + default=None, + metadata={"help": "Training strategy."} ) dataset: Dataset = field( default=None, @@ -217,10 +218,6 @@ class TrainConfig: default=3407, metadata={"help": "Random seed."} ) - dpo_beta: float = field( - default=0.1, - metadata={"help": "DPO beta."} - ) def get_kwargs(self)-> Dict[str, Any]: config_dict = asdict(self) @@ -228,117 +225,191 @@ class TrainConfig: @dataclass -class ScheduleConfig: +class ScheduleConfig(ABC): schedule_type: str = field( - default_factory=["cosine", "sgdr"], - metadata={"help": "Type of learning rate schedule."} + default="cosine", + metadata={ + "help": "Type of learning rate schedule.", + "choices": ["cosine", "sgdr"] + } ) - warning_step: int = field( + warmup_steps: int = field( default=1000, - metadata= {"help": "Warning up step."} - ) - @abstractmethod - def get_kwargs(self)-> Dict[str, Any]: - raise NotImplementedError - - -@dataclass -class CosineScheduleConfig(ScheduleConfig): - total_iters: int = field( - default=None, - metadata={"help": "Total iterations for cosine schedule."} + metadata={"help": "Number of warmup steps."} ) min_rate: float = field( default=0.05, - metadata={"help": "Minimum rate for cosine schedule."} + metadata={"help": "Minimum learning rate multiplier."} + ) + + @abstractmethod + def get_kwargs(self) -> Dict[str, Any]: + raise NotImplementedError + + def validate(self) -> None: + """Validate configuration parameters.""" + if self.warmup_steps < 0: + raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}") + if not 0 <= self.min_rate <= 1: + raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}") + + +@dataclass +class CosineScheduleConfig(ScheduleConfig): + total_steps: int 
= field( # more accurate naming + default=None, + metadata={"help": "Total training steps for cosine schedule."} ) schedule_type: Literal["cosine"] = "cosine" def get_kwargs(self) -> Dict[str, Any]: + if self.total_steps is None: + raise ValueError("total_steps must be specified for cosine schedule") + return { "schedule_type": self.schedule_type, - "warning_step": self.warning_step, - "lr_decay_iters": self.total_iters - self.warning_step, + "warmup_steps": self.warmup_steps, + "lr_decay_steps": self.total_steps - self.warmup_steps, "min_rate": self.min_rate } + + def validate(self) -> None: + super().validate() + if self.total_steps is not None and self.total_steps <= self.warmup_steps: + raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})") + @dataclass class SgdrScheduleConfig(ScheduleConfig): cycle_length: int = field( default=1000, - metadata={"help": "Cycle length for sgdr schedule."} + metadata={"help": "Length of the first cycle in steps."} ) - min_rate: float = field( - default=0.05, - metadata={"help": "Minimum rate for sgdr schedule."} - ) - T_mult: int = field( + t_mult: int = field( default=2, - metadata={"help": "T_mult for sgdr schedule."} + metadata={"help": "Multiplier for cycle length growth."} ) schedule_type: Literal["sgdr"] = "sgdr" def get_kwargs(self) -> Dict[str, Any]: return { "schedule_type": self.schedule_type, - "warning_step": self.warning_step, + "warmup_steps": self.warmup_steps, "cycle_length": self.cycle_length, "min_rate": self.min_rate, - "T_mult": self.T_mult + "t_mult": self.t_mult } + + def validate(self) -> None: + super().validate() + if self.cycle_length <= 0: + raise ValueError(f"cycle_length must be positive, got {self.cycle_length}") + if self.t_mult < 1: + raise ValueError(f"t_mult must be >= 1, got {self.t_mult}") class SchedulerFactory: - + """Factory for creating learning rate schedule functions.""" + @staticmethod def get_sgdr_schedule( - warning_step: int, + warmup_steps: 
int, cycle_length: int, - min_rate: float = 0.1, - T_mult: int = 2 + min_rate: float = 0.05, + t_mult: int = 2 ) -> Callable[[int], float]: - - def sgdr_schedule(now_iter: int) -> float: - if now_iter < warning_step: - return max(min_rate, now_iter / warning_step) - - adjusted_iter = now_iter - warning_step - total_cycles, current_cycle = 0, 0 - while adjusted_iter >= cycle_length * (T_mult ** total_cycles): - current_cycle += 1 - total_cycles += 1 + """ + Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule. + + Args: + warmup_steps: Number of warmup steps + cycle_length: Length of the first cycle + min_rate: Minimum learning rate multiplier + t_mult: Cycle length multiplier - cycle_start = sum(cycle_length * (T_mult ** i) for i in range(current_cycle)) - cycle_pos = adjusted_iter - cycle_start + Returns: + Schedule function that takes current step and returns LR multiplier + """ + + def sgdr_schedule(current_step: int) -> float: + # Warmup phase + if current_step < warmup_steps: + return max(min_rate, current_step / warmup_steps) - cycle_length_current = cycle_length * (T_mult ** current_cycle) - return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / cycle_length_current))) + # SGDR phase + steps_since_warmup = current_step - warmup_steps + + # Find current cycle and position within cycle + cycle_start = 0 + current_cycle_length = cycle_length + cycle_index = 0 + + while steps_since_warmup >= cycle_start + current_cycle_length: + cycle_start += current_cycle_length + current_cycle_length *= t_mult + cycle_index += 1 + + position_in_cycle = steps_since_warmup - cycle_start + progress = position_in_cycle / current_cycle_length + + # Cosine annealing within cycle + return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress))) return sgdr_schedule @staticmethod - def get_cosine_warmup_schedule( - warning_step: int, - lr_decay_iters: int, - min_rate: float = 0.1 + def get_cosine_schedule( + warmup_steps: int, + lr_decay_steps: int, + 
min_rate: float = 0.05 ) -> Callable[[int], float]: - - def cosine_warmup_schedule(now_iter: int) -> float: - if now_iter <= warning_step: - return max(min_rate, now_iter / warning_step) - else: - rate = (now_iter - warning_step) / (lr_decay_iters - warning_step) - return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate))) + """ + Create cosine decay schedule with warmup. - return cosine_warmup_schedule + Args: + warmup_steps: Number of warmup steps + lr_decay_steps: Number of steps for cosine decay after warmup + min_rate: Minimum learning rate multiplier + + Returns: + Schedule function that takes current step and returns LR multiplier + """ + + def cosine_schedule(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return max(min_rate, current_step / warmup_steps) + else: + # Cosine decay + decay_progress = (current_step - warmup_steps) / lr_decay_steps + decay_progress = min(decay_progress, 1.0) # Clamp at 1.0 + return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress))) + + return cosine_schedule @staticmethod - def load_schedule_fn(**kwargs): - strategy = kwargs.pop("schedule_type") - if strategy == "cosine": - return SchedulerFactory.get_cosine_warmup_schedule(**kwargs) - elif strategy == "sgdr": + def create_schedule(config: ScheduleConfig) -> Callable[[int], float]: + """ + Create schedule from configuration. 
+ + Args: + config: Schedule configuration instance + + Returns: + Schedule function + """ + config.validate() + kwargs = config.get_kwargs() + return SchedulerFactory.load_schedule_fn(**kwargs) + + @staticmethod + def load_schedule_fn(**kwargs) -> Callable[[int], float]: + schedule_type = kwargs.pop("schedule_type") + + if schedule_type == "cosine": + return SchedulerFactory.get_cosine_schedule(**kwargs) + elif schedule_type == "sgdr": return SchedulerFactory.get_sgdr_schedule(**kwargs) else: - raise ValueError(f"Invalid schedule type: {strategy}") + raise ValueError(f"Unsupported schedule type: {schedule_type}") \ No newline at end of file diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index d906232..2af62e3 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -8,19 +8,23 @@ from torch.utils.data import DataLoader, RandomSampler from tqdm import tqdm from khaosz.core import ModelParameter, Checkpoint -from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConfig, ScheduleConfig +from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig class Trainer: def __init__( self, - parameter: ModelParameter + parameter: ModelParameter, + train_config: TrainConfig, + schedule_config: ScheduleConfig ): self.checkpoint = Checkpoint( model=parameter.model, tokenizer=parameter.tokenizer, config=parameter.config, ) + self.train_config = train_config + self.schedule_config = schedule_config def save_checkpoint( self, @@ -35,12 +39,11 @@ class Trainer: def train( self, - train_config: TrainConfig, - schedule_config: ScheduleConfig, train_checkpoint: Optional[Checkpoint] = None ) -> Checkpoint: + train_config = self.train_config + schedule_config = self.schedule_config assert schedule_config.schedule_type in ["cosine", "sgdr"] - assert train_config.train_type in ["seq", "sft", "dpo"] if train_checkpoint: self.checkpoint = train_checkpoint @@ -60,19 +63,6 @@ class Trainer: 
**schedule_config.get_kwargs() ) - strategy_kwargs = { - "bos_token_id": self.checkpoint.tokenizer.bos_id, - "eos_token_id": self.checkpoint.tokenizer.eos_id, - "pad_token_id": self.checkpoint.tokenizer.pad_id, - "dpo_beta": train_config.dpo_beta - } - - strategy = StrategyFactory.load( - self.checkpoint.model, - train_config.train_type, - **strategy_kwargs - ) - scheduler = LambdaLR( train_config.optimizer, lambda_scheduler_fn, @@ -98,7 +88,7 @@ class Trainer: ) for batch in progress_bar: #forward - loss = strategy(batch) + loss = train_config.strategy(batch) loss_list.append(loss.item()) #backward loss.backward() diff --git a/train.py b/train.py index 73695bf..b7274fc 100644 --- a/train.py +++ b/train.py @@ -5,6 +5,7 @@ import torch from torch.optim import AdamW from khaosz.core import ParameterLoader from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig +from khaosz.trainer import StrategyFactory PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) @@ -46,19 +47,26 @@ def train( cache_files = get_files(data_root_path) - dataset_kwargs = { + strategy_kwargs = { "multi_turn": multi_turn, "bos_token_id": parameter.tokenizer.bos_id, "eos_token_id": parameter.tokenizer.eos_id, - "user_token_id":parameter.tokenizer.encode("<|user|>")[0] + "user_token_id":parameter.tokenizer.encode("<|user|>")[0], + "dpo_beta": dpo_beta } + strategy = StrategyFactory.load( + model, + train_type, + **strategy_kwargs + ) + dataset = DatasetLoader.load( train_type=train_type, load_path=cache_files, max_len=parameter.config.m_len, device=device, - dataset_kwargs=dataset_kwargs + dataset_kwargs=strategy_kwargs ) param_groups = [ @@ -73,7 +81,7 @@ def train( ) train_config = TrainConfig( - train_type=train_type, + strategy=strategy, dataset=dataset, optimizer=optim, ckpt_dir=ckpt_dir, @@ -83,7 +91,6 @@ def train( n_iter_step=n_iter_step, max_grad_norm=max_grad_norm, random_seed=random_seed, - dpo_beta=dpo_beta ) schedule_config = CosineScheduleConfig( @@ 
-91,11 +98,13 @@ def train( - total_iters=len(dataset) * n_epoch // batch_size, + total_steps=len(dataset) * n_epoch // batch_size, ) - trainer = Trainer(parameter) - trainer.train( + trainer = Trainer( + parameter=parameter, train_config=train_config, schedule_config=schedule_config, ) + trainer.train() + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train the Transformer model.")