feat( StrategyFactory): 添加 SFT 策略初始化参数并完善工厂方法调用

This commit is contained in:
ViperEkura 2025-09-27 13:24:16 +08:00
parent 676fdd59d7
commit 053f4a4dad
2 changed files with 19 additions and 7 deletions

View File

@ -12,7 +12,7 @@ from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id): def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
input_mask = input_ids.ne(pad_token_id) input_mask = input_ids.ne(pad_token_id)
logits = model(input_ids, input_mask)["logits"] logits = model(input_ids, input_mask)["logits"]
log_probs = torch.log_softmax(logits, dim=-1) log_probs = torch.log_softmax(logits, dim=-1)
@ -47,7 +47,7 @@ def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) ->
return loss_mask.to(dtype=torch.bool) return loss_mask.to(dtype=torch.bool)
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor: def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
bsz, seq_len = input_ids.size() bsz, seq_len = input_ids.size()
is_user_token = input_ids.eq(user_token_id) is_user_token = input_ids.eq(user_token_id)
turn_id = is_user_token.cumsum(dim=-1) turn_id = is_user_token.cumsum(dim=-1)
@ -90,8 +90,18 @@ class SeqStrategy(BaseStrategy):
class SftStrategy(BaseStrategy): class SftStrategy(BaseStrategy):
def __init__(self, model): def __init__(
self,
model: nn.Module,
bos_id: int,
eos_id: int,
user_token_id: int,
multi_turn: bool
):
super().__init__(model) super().__init__(model)
self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id)
self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn)
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
x, y, loss_mask = batch x, y, loss_mask = batch
@ -179,7 +189,7 @@ class StrategyFactory:
def load(model, train_type, **kwargs): def load(model, train_type, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = { train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model), "seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(model), "sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")),
"dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta")) "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
} }
strategy = train_strategy[train_type]() strategy = train_strategy[train_type]()

View File

@ -86,9 +86,11 @@ class Trainer:
strategy = StrategyFactory.load( strategy = StrategyFactory.load(
self.model, self.model,
train_config.train_type, train_type=train_config.train_type,
self.tokenizer.pad_id, bos_token_id=self.tokenizer.bos_id,
train_config.dpo_beta eos_token_id=self.tokenizer.eos_id,
pad_token_id=self.tokenizer.pad_id,
dpo_beta=train_config.dpo_beta
) )
scheduler = LambdaLR( scheduler = LambdaLR(