chore: 修改拼写错误问题
This commit is contained in:
parent
3f67e53088
commit
3fee87897d
|
|
@ -76,7 +76,7 @@ class Transformer(AutoModel):
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rotary_embeding = RotaryEmbedding(
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
config.dim // config.n_heads, config.max_len
|
config.dim // config.n_heads, config.max_len
|
||||||
)
|
)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
@ -152,7 +152,7 @@ class Transformer(AutoModel):
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
rotary_emb = self.rotary_embeding(x, start_pos)
|
rotary_emb = self.rotary_embedding(x, start_pos)
|
||||||
|
|
||||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -237,14 +237,14 @@ class DPOStrategy(BaseStrategy):
|
||||||
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
|
||||||
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
|
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||||
|
|
||||||
contact_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
|
||||||
contact_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
|
||||||
|
|
||||||
log_pi = get_logprobs(self.model, contact_ids, contact_mask, self.reduction)
|
log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
log_ref = get_logprobs(
|
log_ref = get_logprobs(
|
||||||
self.ref_model, contact_ids, contact_mask, self.reduction
|
self.ref_model, concat_ids, concat_mask, self.reduction
|
||||||
)
|
)
|
||||||
|
|
||||||
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue