chore: 修改拼写错误问题

This commit is contained in:
ViperEkura 2026-04-06 09:28:16 +08:00
parent 3f67e53088
commit 3fee87897d
2 changed files with 6 additions and 6 deletions

View File

@ -76,7 +76,7 @@ class Transformer(AutoModel):
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.rotary_embeding = RotaryEmbedding( self.rotary_embedding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len config.dim // config.n_heads, config.max_len
) )
self.embed_tokens = Embedding(config.vocab_size, config.dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
@ -152,7 +152,7 @@ class Transformer(AutoModel):
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embeding(x, start_pos) rotary_emb = self.rotary_embedding(x, start_pos)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)

View File

@ -237,14 +237,14 @@ class DPOStrategy(BaseStrategy):
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"] chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"] chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0) concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0) concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction) log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
with torch.no_grad(): with torch.no_grad():
log_ref = get_logprobs( log_ref = get_logprobs(
self.ref_model, contact_ids, contact_mask, self.reduction self.ref_model, concat_ids, concat_mask, self.reduction
) )
log_pi_chosen = log_pi[: chosen_ids.shape[0]] log_pi_chosen = log_pi[: chosen_ids.shape[0]]