chore: 修改拼写错误问题

This commit is contained in:
ViperEkura 2026-04-06 09:28:16 +08:00
parent 3f67e53088
commit 3fee87897d
2 changed files with 6 additions and 6 deletions

View File

@ -76,7 +76,7 @@ class Transformer(AutoModel):
def __init__(self, config: ModelConfig):
super().__init__(config)
self.config = config
self.rotary_embeding = RotaryEmbedding(
self.rotary_embedding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
@ -152,7 +152,7 @@ class Transformer(AutoModel):
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embeding(x, start_pos)
rotary_emb = self.rotary_embedding(x, start_pos)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)

View File

@ -237,14 +237,14 @@ class DPOStrategy(BaseStrategy):
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
with torch.no_grad():
log_ref = get_logprobs(
self.ref_model, contact_ids, contact_mask, self.reduction
self.ref_model, concat_ids, concat_mask, self.reduction
)
log_pi_chosen = log_pi[: chosen_ids.shape[0]]