chore: 修改拼写错误问题
This commit is contained in:
parent
3f67e53088
commit
3fee87897d
|
|
@ -76,7 +76,7 @@ class Transformer(AutoModel):
|
|||
def __init__(self, config: ModelConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.rotary_embeding = RotaryEmbedding(
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
config.dim // config.n_heads, config.max_len
|
||||
)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
|
@ -152,7 +152,7 @@ class Transformer(AutoModel):
|
|||
assert input_ids.ndim == 2
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
rotary_emb = self.rotary_embeding(x, start_pos)
|
||||
rotary_emb = self.rotary_embedding(x, start_pos)
|
||||
|
||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -237,14 +237,14 @@ class DPOStrategy(BaseStrategy):
|
|||
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
||||
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||
|
||||
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
||||
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
||||
concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
||||
concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
||||
|
||||
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
|
||||
log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
|
||||
|
||||
with torch.no_grad():
|
||||
log_ref = get_logprobs(
|
||||
self.ref_model, contact_ids, contact_mask, self.reduction
|
||||
self.ref_model, concat_ids, concat_mask, self.reduction
|
||||
)
|
||||
|
||||
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
||||
|
|
|
|||
Loading…
Reference in New Issue