feat(trainer): 重构数据集与策略模块以支持字典形式的数据返回
This commit is contained in:
parent
9fbc9481b5
commit
4fcdc87c95
|
|
@ -25,6 +25,35 @@ def load_pkl_files(paths: List[str]):
|
||||||
|
|
||||||
return segments, total_samples
|
return segments, total_samples
|
||||||
|
|
||||||
|
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||||
|
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||||
|
|
||||||
|
is_bos_token = input_ids.eq(bos_token_id)
|
||||||
|
is_eos_token = input_ids.eq(eos_token_id)
|
||||||
|
|
||||||
|
token_markers[is_bos_token] = 1
|
||||||
|
token_markers[is_eos_token] = -1
|
||||||
|
|
||||||
|
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
||||||
|
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
||||||
|
loss_mask = cumulative_markers - min_cumulative
|
||||||
|
|
||||||
|
return loss_mask.to(dtype=torch.bool)
|
||||||
|
|
||||||
|
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
|
||||||
|
seq_len = input_ids.size(0)
|
||||||
|
is_user_token = input_ids.eq(user_token_id)
|
||||||
|
turn_id = is_user_token.cumsum(dim=-1)
|
||||||
|
|
||||||
|
iq = turn_id.view(seq_len, 1)
|
||||||
|
ik = turn_id.view(1, seq_len)
|
||||||
|
|
||||||
|
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
|
||||||
|
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
||||||
|
attention_mask = seq_mask & causal_mask
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class BaseSegmentFetcher:
|
||||||
def __init__(self, segments: List[Tensor]):
|
def __init__(self, segments: List[Tensor]):
|
||||||
|
|
@ -103,7 +132,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -112,7 +141,11 @@ class BaseDataset(Dataset, ABC):
|
||||||
|
|
||||||
|
|
||||||
class SeqDataset(BaseDataset):
|
class SeqDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size , device='cuda'):
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size,
|
||||||
|
device='cuda'
|
||||||
|
):
|
||||||
super().__init__(chunk_size, device)
|
super().__init__(chunk_size, device)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
|
|
@ -126,13 +159,26 @@ class SeqDataset(BaseDataset):
|
||||||
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
|
||||||
|
|
||||||
return x, y
|
return {"input_ids": x, "target_ids": y}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SftDataset(BaseDataset):
|
class SftDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size, device='cuda'):
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size,
|
||||||
|
bos_token_id,
|
||||||
|
eos_token_id,
|
||||||
|
user_token_id,
|
||||||
|
multi_turn=False,
|
||||||
|
device='cuda'
|
||||||
|
):
|
||||||
super().__init__(chunk_size, device)
|
super().__init__(chunk_size, device)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.user_token_id = user_token_id
|
||||||
|
self.multi_turn = multi_turn
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||||
|
|
@ -143,9 +189,11 @@ class SftDataset(BaseDataset):
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
||||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool)
|
|
||||||
|
|
||||||
return x, y, loss_mask
|
loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id)
|
||||||
|
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
|
||||||
|
|
||||||
|
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
|
||||||
|
|
||||||
|
|
||||||
class DpoDataset(BaseDataset):
|
class DpoDataset(BaseDataset):
|
||||||
|
|
@ -165,7 +213,7 @@ class DpoDataset(BaseDataset):
|
||||||
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
|
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
|
||||||
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
|
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
|
||||||
|
|
||||||
return chosen, rejected, chosen_mask, rejected_mask
|
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||||
|
|
||||||
|
|
||||||
class PpoDataset(BaseDataset):
|
class PpoDataset(BaseDataset):
|
||||||
|
|
@ -187,7 +235,7 @@ class PpoDataset(BaseDataset):
|
||||||
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
|
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
|
||||||
|
|
||||||
return input_ids, actions, logprobs, rewards
|
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
|
||||||
|
|
||||||
|
|
||||||
class DatasetLoader:
|
class DatasetLoader:
|
||||||
|
|
@ -196,12 +244,20 @@ class DatasetLoader:
|
||||||
train_type: Literal["seq", "sft", "dpo"],
|
train_type: Literal["seq", "sft", "dpo"],
|
||||||
load_path: Union[str, List[str]],
|
load_path: Union[str, List[str]],
|
||||||
max_len: int,
|
max_len: int,
|
||||||
device: str
|
device: str,
|
||||||
|
**kwargs
|
||||||
) -> BaseDataset:
|
) -> BaseDataset:
|
||||||
|
|
||||||
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
|
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
|
||||||
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
|
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
|
||||||
"sft": lambda m_len, device: SftDataset(m_len, device=device),
|
"sft": lambda m_len, device: SftDataset(
|
||||||
|
m_len,
|
||||||
|
device=device,
|
||||||
|
bos_token_id=kwargs.get("bos_token_id"),
|
||||||
|
eos_token_id=kwargs.get("eos_token_id"),
|
||||||
|
user_token_id=kwargs.get("user_token_id"),
|
||||||
|
multi_turn=kwargs.get("multi_turn", False)
|
||||||
|
),
|
||||||
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
|
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
|
||||||
}
|
}
|
||||||
dataset = dataset_router[train_type](max_len, device)
|
dataset = dataset_router[train_type](max_len, device)
|
||||||
|
|
|
||||||
|
|
@ -32,42 +32,13 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id:
|
||||||
|
|
||||||
return (token_logprobs * valid_mask).sum(dim=-1)
|
return (token_logprobs * valid_mask).sum(dim=-1)
|
||||||
|
|
||||||
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
|
||||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
|
||||||
|
|
||||||
is_bos_token = input_ids.eq(bos_token_id)
|
|
||||||
is_eos_token = input_ids.eq(eos_token_id)
|
|
||||||
|
|
||||||
token_markers[is_bos_token] = 1
|
|
||||||
token_markers[is_eos_token] = -1
|
|
||||||
|
|
||||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
|
||||||
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
|
||||||
loss_mask = cumulative_markers - min_cumulative
|
|
||||||
|
|
||||||
return loss_mask.to(dtype=torch.bool)
|
|
||||||
|
|
||||||
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
|
|
||||||
bsz, seq_len = input_ids.size()
|
|
||||||
is_user_token = input_ids.eq(user_token_id)
|
|
||||||
turn_id = is_user_token.cumsum(dim=-1)
|
|
||||||
|
|
||||||
iq = turn_id.view(bsz, seq_len, 1)
|
|
||||||
ik = turn_id.view(bsz, 1, seq_len)
|
|
||||||
|
|
||||||
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
|
|
||||||
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
|
||||||
attention_mask = seq_mask & causal_mask
|
|
||||||
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
def __init__(self, model: nn.Module):
|
def __init__(self, model: nn.Module):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||||
|
|
@ -78,47 +49,45 @@ class SeqStrategy(BaseStrategy):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__(model)
|
super().__init__(model)
|
||||||
|
|
||||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
x, y = batch
|
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||||
B, L = x.size()
|
B, L = input_ids.size()
|
||||||
logits: Tensor = self.model(x)["logits"]
|
logits: Tensor = self.model(input_ids=input_ids)["logits"]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
logits.view(B * L, -1), y.flatten()
|
input=logits.view(B * L, -1),
|
||||||
|
target=target_ids.flatten()
|
||||||
)
|
)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class SftStrategy(BaseStrategy):
|
class SftStrategy(BaseStrategy):
|
||||||
def __init__(
|
def __init__(self, model: nn.Module):
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
bos_id: int,
|
|
||||||
eos_id: int,
|
|
||||||
user_token_id: int,
|
|
||||||
multi_turn: bool
|
|
||||||
):
|
|
||||||
super().__init__(model)
|
super().__init__(model)
|
||||||
|
|
||||||
self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id)
|
|
||||||
self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn)
|
|
||||||
|
|
||||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
x, y, loss_mask = batch
|
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||||
B, L = x.size()
|
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
|
||||||
ignore_idx = -1
|
|
||||||
|
|
||||||
logits: Tensor = self.model(x)["logits"]
|
ignore_index = -100
|
||||||
masked_y = y.masked_fill(loss_mask == 0, ignore_idx)
|
B, L = input_ids.size()
|
||||||
|
|
||||||
|
logits: Tensor = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
input_mask=attn_mask
|
||||||
|
)["logits"]
|
||||||
|
|
||||||
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
logits.view(B * L, -1),
|
input=logits.view(B * L, -1),
|
||||||
masked_y.flatten(),
|
target=target_ids.flatten(),
|
||||||
ignore_index=ignore_idx
|
ignore_index=ignore_index
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class DpoStrategy(BaseStrategy):
|
class DpoStrategy(BaseStrategy):
|
||||||
def __init__(self, model, pad_token_id, beta):
|
def __init__(self, model, pad_token_id, beta):
|
||||||
super().__init__(model)
|
super().__init__(model)
|
||||||
|
|
@ -131,7 +100,8 @@ class DpoStrategy(BaseStrategy):
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
|
|
||||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
||||||
good_ids, bad_ids, good_mask, bad_mask = batch
|
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
||||||
|
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||||
|
|
||||||
log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
|
log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
|
||||||
log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)
|
log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)
|
||||||
|
|
|
||||||
|
|
@ -46,10 +46,10 @@ def test_env():
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return (
|
return {
|
||||||
torch.randint(0, 1000, (64,)),
|
"input_ids": torch.randint(0, 1000, (64,)),
|
||||||
torch.randint(0, 1000, (64,))
|
"target_ids": torch.randint(0, 1000, (64,))
|
||||||
)
|
}
|
||||||
|
|
||||||
dataset = DummyDataset()
|
dataset = DummyDataset()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue