feat(trainer): 重构数据集与策略模块以支持字典形式的数据返回
This commit is contained in:
parent
9fbc9481b5
commit
4fcdc87c95
|
|
@ -25,6 +25,35 @@ def load_pkl_files(paths: List[str]):
|
|||
|
||||
return segments, total_samples
|
||||
|
||||
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
    """Mark the tokens that should contribute to the loss.

    Positions strictly after a BOS token and up to (and including) the next
    EOS token form an "open span"; those positions get ``True``.

    Args:
        input_ids: integer token-id tensor; the mask is computed along the
            last dimension.
        bos_token_id: token id that opens a span (+1 marker).
        eos_token_id: token id that closes a span (-1 marker).

    Returns:
        Boolean tensor of the same shape as ``input_ids``.
    """
    # +1 where a span opens, -1 where one closes (EOS wins if ids collide).
    deltas = torch.zeros_like(input_ids, dtype=torch.int8)
    deltas[input_ids.eq(bos_token_id)] = 1
    deltas[input_ids.eq(eos_token_id)] = -1

    # Running nesting depth along the sequence.
    depth = deltas.cumsum(dim=-1)

    # Shift by the running minimum so a chunk that begins inside an open
    # span (an EOS seen before any BOS) is still masked correctly.
    baseline = depth.min(dim=-1, keepdim=True).values

    return (depth - baseline).bool()
|
||||
|
||||
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
    """Build a ``(seq_len, seq_len)`` boolean attention mask for one sequence.

    Every occurrence of ``user_token_id`` starts a new turn. With
    ``multi_turn=False`` a token may only attend to tokens of its own turn;
    with ``multi_turn=True`` it may also attend to all earlier turns. The
    result is always intersected with a causal (lower-triangular) mask.

    Args:
        input_ids: 1-D tensor of token ids.
        user_token_id: token id that marks the start of a new turn.
        multi_turn: allow attention across earlier turns when ``True``.

    Returns:
        Boolean tensor of shape ``(seq_len, seq_len)``.
    """
    seq_len = input_ids.size(0)
    is_user_token = input_ids.eq(user_token_id)
    # turn_id is non-decreasing along the sequence.
    turn_id = is_user_token.cumsum(dim=-1)

    iq = turn_id.view(seq_len, 1)  # turn of the query position
    ik = turn_id.view(1, seq_len)  # turn of the key position

    # BUG FIX: the original used (iq <= ik) for multi_turn, but under the
    # causal mask ik <= iq always holds, so (iq <= ik) collapsed to
    # (iq == ik) and multi_turn had no effect. Attending to earlier turns
    # requires (iq >= ik).
    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
    attention_mask = seq_mask & causal_mask

    return attention_mask
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
|
|
@ -103,7 +132,7 @@ class BaseDataset(Dataset, ABC):
|
|||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, index: int):
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
@ -112,7 +141,11 @@ class BaseDataset(Dataset, ABC):
|
|||
|
||||
|
||||
class SeqDataset(BaseDataset):
|
||||
def __init__(self, chunk_size , device='cuda'):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size,
|
||||
device='cuda'
|
||||
):
|
||||
super().__init__(chunk_size, device)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
|
|
@ -126,13 +159,26 @@ class SeqDataset(BaseDataset):
|
|||
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
|
||||
|
||||
return x, y
|
||||
return {"input_ids": x, "target_ids": y}
|
||||
|
||||
|
||||
|
||||
class SftDataset(BaseDataset):
|
||||
def __init__(self, chunk_size, device='cuda'):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size,
|
||||
bos_token_id,
|
||||
eos_token_id,
|
||||
user_token_id,
|
||||
multi_turn=False,
|
||||
device='cuda'
|
||||
):
|
||||
super().__init__(chunk_size, device)
|
||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.user_token_id = user_token_id
|
||||
self.multi_turn = multi_turn
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
|
@ -143,9 +189,11 @@ class SftDataset(BaseDataset):
|
|||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool)
|
||||
|
||||
return x, y, loss_mask
|
||||
loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id)
|
||||
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
|
||||
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
|
||||
|
||||
|
||||
class DpoDataset(BaseDataset):
|
||||
|
|
@ -165,7 +213,7 @@ class DpoDataset(BaseDataset):
|
|||
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
|
||||
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
|
||||
|
||||
return chosen, rejected, chosen_mask, rejected_mask
|
||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||
|
||||
|
||||
class PpoDataset(BaseDataset):
|
||||
|
|
@ -187,7 +235,7 @@ class PpoDataset(BaseDataset):
|
|||
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
|
||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
|
||||
|
||||
return input_ids, actions, logprobs, rewards
|
||||
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
|
||||
|
||||
|
||||
class DatasetLoader:
|
||||
|
|
@ -196,12 +244,20 @@ class DatasetLoader:
|
|||
train_type: Literal["seq", "sft", "dpo"],
|
||||
load_path: Union[str, List[str]],
|
||||
max_len: int,
|
||||
device: str
|
||||
device: str,
|
||||
**kwargs
|
||||
) -> BaseDataset:
|
||||
|
||||
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
|
||||
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
|
||||
"sft": lambda m_len, device: SftDataset(m_len, device=device),
|
||||
"sft": lambda m_len, device: SftDataset(
|
||||
m_len,
|
||||
device=device,
|
||||
bos_token_id=kwargs.get("bos_token_id"),
|
||||
eos_token_id=kwargs.get("eos_token_id"),
|
||||
user_token_id=kwargs.get("user_token_id"),
|
||||
multi_turn=kwargs.get("multi_turn", False)
|
||||
),
|
||||
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
|
||||
}
|
||||
dataset = dataset_router[train_type](max_len, device)
|
||||
|
|
|
|||
|
|
@ -32,42 +32,13 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id:
|
|||
|
||||
return (token_logprobs * valid_mask).sum(dim=-1)
|
||||
|
||||
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
    """Return a boolean mask selecting tokens inside BOS..EOS spans.

    A position is ``True`` when it lies after a BOS token and no later than
    the matching EOS token, computed along the last dimension of
    ``input_ids``.

    Args:
        input_ids: integer token-id tensor.
        bos_token_id: id whose occurrences open a span.
        eos_token_id: id whose occurrences close a span.

    Returns:
        Boolean tensor with the same shape as ``input_ids``.
    """
    markers = torch.zeros_like(input_ids, dtype=torch.int8)
    # EOS is written last so it overrides BOS if the two ids ever coincide.
    markers = markers.masked_fill(input_ids.eq(bos_token_id), 1)
    markers = markers.masked_fill(input_ids.eq(eos_token_id), -1)

    running = markers.cumsum(dim=-1)
    # Re-anchor at the running minimum so sequences that start inside an
    # already-open span (EOS before any BOS) still mask correctly.
    floor = running.min(dim=-1, keepdim=True).values

    return (running - floor) != 0
|
||||
|
||||
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
    """Build a batched ``(bsz, seq_len, seq_len)`` boolean attention mask.

    Every occurrence of ``user_token_id`` starts a new turn. With
    ``multi_turn=False`` a token may only attend within its own turn; with
    ``multi_turn=True`` it may additionally attend to earlier turns. The
    result is always intersected with a causal (lower-triangular) mask,
    broadcast over the batch dimension.

    Args:
        input_ids: 2-D tensor of token ids, shape ``(bsz, seq_len)``.
        user_token_id: token id that marks the start of a new turn.
        multi_turn: allow attention across earlier turns when ``True``.

    Returns:
        Boolean tensor of shape ``(bsz, seq_len, seq_len)``.
    """
    bsz, seq_len = input_ids.size()
    is_user_token = input_ids.eq(user_token_id)
    # turn_id is non-decreasing along each sequence.
    turn_id = is_user_token.cumsum(dim=-1)

    iq = turn_id.view(bsz, seq_len, 1)  # turn of the query position
    ik = turn_id.view(bsz, 1, seq_len)  # turn of the key position

    # BUG FIX: the original used (iq <= ik) for multi_turn, but under the
    # causal mask ik <= iq always holds, so (iq <= ik) collapsed to
    # (iq == ik) and multi_turn had no effect. Attending to earlier turns
    # requires (iq >= ik).
    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
    attention_mask = seq_mask & causal_mask

    return attention_mask
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
def __init__(self, model: nn.Module):
|
||||
self.model = model
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
|
|
@ -78,47 +49,45 @@ class SeqStrategy(BaseStrategy):
|
|||
def __init__(self, model):
|
||||
super().__init__(model)
|
||||
|
||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
x, y = batch
|
||||
B, L = x.size()
|
||||
logits: Tensor = self.model(x)["logits"]
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
B, L = input_ids.size()
|
||||
logits: Tensor = self.model(input_ids=input_ids)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
logits.view(B * L, -1), y.flatten()
|
||||
input=logits.view(B * L, -1),
|
||||
target=target_ids.flatten()
|
||||
)
|
||||
return loss
|
||||
|
||||
|
||||
class SftStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
bos_id: int,
|
||||
eos_id: int,
|
||||
user_token_id: int,
|
||||
multi_turn: bool
|
||||
):
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__(model)
|
||||
|
||||
self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id)
|
||||
self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn)
|
||||
|
||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
x, y, loss_mask = batch
|
||||
B, L = x.size()
|
||||
ignore_idx = -1
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
|
||||
|
||||
logits: Tensor = self.model(x)["logits"]
|
||||
masked_y = y.masked_fill(loss_mask == 0, ignore_idx)
|
||||
ignore_index = -100
|
||||
B, L = input_ids.size()
|
||||
|
||||
logits: Tensor = self.model(
|
||||
input_ids=input_ids,
|
||||
input_mask=attn_mask
|
||||
)["logits"]
|
||||
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
logits.view(B * L, -1),
|
||||
masked_y.flatten(),
|
||||
ignore_index=ignore_idx
|
||||
input=logits.view(B * L, -1),
|
||||
target=target_ids.flatten(),
|
||||
ignore_index=ignore_index
|
||||
)
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class DpoStrategy(BaseStrategy):
|
||||
def __init__(self, model, pad_token_id, beta):
|
||||
super().__init__(model)
|
||||
|
|
@ -131,7 +100,8 @@ class DpoStrategy(BaseStrategy):
|
|||
self.beta = beta
|
||||
|
||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||
good_ids, bad_ids, good_mask, bad_mask = batch
|
||||
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
||||
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||
|
||||
log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
|
||||
log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)
|
||||
|
|
|
|||
|
|
@ -46,10 +46,10 @@ def test_env():
|
|||
return self.length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return (
|
||||
torch.randint(0, 1000, (64,)),
|
||||
torch.randint(0, 1000, (64,))
|
||||
)
|
||||
return {
|
||||
"input_ids": torch.randint(0, 1000, (64,)),
|
||||
"target_ids": torch.randint(0, 1000, (64,))
|
||||
}
|
||||
|
||||
dataset = DummyDataset()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue