diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 2198a4b..553f682 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -76,7 +76,7 @@ class Transformer(AutoModel): def __init__(self, config: ModelConfig): super().__init__(config) self.config = config - self.rotary_embeding = RotaryEmbedding( + self.rotary_embedding = RotaryEmbedding( config.dim // config.n_heads, config.max_len ) self.embed_tokens = Embedding(config.vocab_size, config.dim) @@ -152,7 +152,7 @@ class Transformer(AutoModel): assert input_ids.ndim == 2 x = self.embed_tokens(input_ids) - rotary_emb = self.rotary_embeding(x, start_pos) + rotary_emb = self.rotary_embedding(x, start_pos) attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index fd3d1d5..eb5a6e5 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -237,14 +237,14 @@ class DPOStrategy(BaseStrategy): chosen_ids, rejected_ids = batch["chosen"], batch["rejected"] chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"] - contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0) - contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0) + concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0) + concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0) - log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction) + log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction) with torch.no_grad(): log_ref = get_logprobs( - self.ref_model, contact_ids, contact_mask, self.reduction + self.ref_model, concat_ids, concat_mask, self.reduction ) log_pi_chosen = log_pi[: chosen_ids.shape[0]]