From 3fee87897de3a3ae5adbbeb2333cebde595d6f22 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 6 Apr 2026 09:28:16 +0800
Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BF=AE=E6=94=B9=E6=8B=BC=E5=86=99?=
 =?UTF-8?q?=E9=94=99=E8=AF=AF=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrai/model/transformer.py | 4 ++--
 astrai/trainer/strategy.py  | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py
index 2198a4b..553f682 100644
--- a/astrai/model/transformer.py
+++ b/astrai/model/transformer.py
@@ -76,7 +76,7 @@ class Transformer(AutoModel):
     def __init__(self, config: ModelConfig):
         super().__init__(config)
         self.config = config
-        self.rotary_embeding = RotaryEmbedding(
+        self.rotary_embedding = RotaryEmbedding(
             config.dim // config.n_heads, config.max_len
         )
         self.embed_tokens = Embedding(config.vocab_size, config.dim)
@@ -152,7 +152,7 @@ class Transformer(AutoModel):
         assert input_ids.ndim == 2
 
         x = self.embed_tokens(input_ids)
-        rotary_emb = self.rotary_embeding(x, start_pos)
+        rotary_emb = self.rotary_embedding(x, start_pos)
 
         attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
 
diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py
index fd3d1d5..eb5a6e5 100644
--- a/astrai/trainer/strategy.py
+++ b/astrai/trainer/strategy.py
@@ -237,14 +237,14 @@ class DPOStrategy(BaseStrategy):
         chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
         chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
 
-        contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
-        contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
+        concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
+        concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
 
-        log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
+        log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
 
         with torch.no_grad():
             log_ref = get_logprobs(
-                self.ref_model, contact_ids, contact_mask, self.reduction
+                self.ref_model, concat_ids, concat_mask, self.reduction
             )
 
         log_pi_chosen = log_pi[: chosen_ids.shape[0]]