feat(StrategyFactory): 添加 SFT 策略初始化参数并完善工厂方法调用
This commit is contained in:
parent
676fdd59d7
commit
053f4a4dad
|
|
@ -12,7 +12,7 @@ from abc import ABC, abstractmethod
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
|
||||||
|
|
||||||
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id):
|
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
|
||||||
input_mask = input_ids.ne(pad_token_id)
|
input_mask = input_ids.ne(pad_token_id)
|
||||||
logits = model(input_ids, input_mask)["logits"]
|
logits = model(input_ids, input_mask)["logits"]
|
||||||
log_probs = torch.log_softmax(logits, dim=-1)
|
log_probs = torch.log_softmax(logits, dim=-1)
|
||||||
|
|
@ -47,7 +47,7 @@ def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) ->
|
||||||
|
|
||||||
return loss_mask.to(dtype=torch.bool)
|
return loss_mask.to(dtype=torch.bool)
|
||||||
|
|
||||||
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor:
|
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
|
||||||
bsz, seq_len = input_ids.size()
|
bsz, seq_len = input_ids.size()
|
||||||
is_user_token = input_ids.eq(user_token_id)
|
is_user_token = input_ids.eq(user_token_id)
|
||||||
turn_id = is_user_token.cumsum(dim=-1)
|
turn_id = is_user_token.cumsum(dim=-1)
|
||||||
|
|
@ -90,9 +90,19 @@ class SeqStrategy(BaseStrategy):
|
||||||
|
|
||||||
|
|
||||||
class SftStrategy(BaseStrategy):
|
class SftStrategy(BaseStrategy):
|
||||||
def __init__(self, model):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
bos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
user_token_id: int,
|
||||||
|
multi_turn: bool
|
||||||
|
):
|
||||||
super().__init__(model)
|
super().__init__(model)
|
||||||
|
|
||||||
|
self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id)
|
||||||
|
self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn)
|
||||||
|
|
||||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||||
x, y, loss_mask = batch
|
x, y, loss_mask = batch
|
||||||
B, L = x.size()
|
B, L = x.size()
|
||||||
|
|
@ -179,7 +189,7 @@ class StrategyFactory:
|
||||||
def load(model, train_type, **kwargs):
|
def load(model, train_type, **kwargs):
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||||
"seq": lambda: SeqStrategy(model),
|
"seq": lambda: SeqStrategy(model),
|
||||||
"sft": lambda: SftStrategy(model),
|
"sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")),
|
||||||
"dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
|
"dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
|
||||||
}
|
}
|
||||||
strategy = train_strategy[train_type]()
|
strategy = train_strategy[train_type]()
|
||||||
|
|
|
||||||
|
|
@ -86,9 +86,11 @@ class Trainer:
|
||||||
|
|
||||||
strategy = StrategyFactory.load(
|
strategy = StrategyFactory.load(
|
||||||
self.model,
|
self.model,
|
||||||
train_config.train_type,
|
train_type=train_config.train_type,
|
||||||
self.tokenizer.pad_id,
|
bos_token_id=self.tokenizer.bos_id,
|
||||||
train_config.dpo_beta
|
eos_token_id=self.tokenizer.eos_id,
|
||||||
|
pad_token_id=self.tokenizer.pad_id,
|
||||||
|
dpo_beta=train_config.dpo_beta
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler = LambdaLR(
|
scheduler = LambdaLR(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue