feat(trainer): 重构数据集与策略模块以支持字典形式的数据返回

This commit is contained in:
ViperEkura 2025-09-27 14:11:27 +08:00
parent 9fbc9481b5
commit 4fcdc87c95
3 changed files with 97 additions and 71 deletions

View File

@ -25,6 +25,35 @@ def load_pkl_files(paths: List[str]):
return segments, total_samples
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
seq_len = input_ids.size(0)
is_user_token = input_ids.eq(user_token_id)
turn_id = is_user_token.cumsum(dim=-1)
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask
return attention_mask
class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]):
@ -103,7 +132,7 @@ class BaseDataset(Dataset, ABC):
self.fetcher = MutiSegmentFetcher(self.segments)
@abstractmethod
def __getitem__(self, index: int):
def __getitem__(self, index: int) -> Dict[str, Tensor]:
raise NotImplementedError
def __len__(self) -> int:
@ -112,7 +141,11 @@ class BaseDataset(Dataset, ABC):
class SeqDataset(BaseDataset):
def __init__(self, chunk_size , device='cuda'):
def __init__(
self,
chunk_size,
device='cuda'
):
super().__init__(chunk_size, device)
self.fetcher = MutiSegmentFetcher(self.segments)
@ -126,13 +159,26 @@ class SeqDataset(BaseDataset):
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
return x, y
return {"input_ids": x, "target_ids": y}
class SftDataset(BaseDataset):
def __init__(self, chunk_size, device='cuda'):
def __init__(
self,
chunk_size,
bos_token_id,
eos_token_id,
user_token_id,
multi_turn=False,
device='cuda'
):
super().__init__(chunk_size, device)
self.fetcher = MutiSegmentFetcher(self.segments)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.multi_turn = multi_turn
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
@ -143,9 +189,11 @@ class SftDataset(BaseDataset):
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool)
return x, y, loss_mask
loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id)
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
class DpoDataset(BaseDataset):
@ -165,7 +213,7 @@ class DpoDataset(BaseDataset):
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
return chosen, rejected, chosen_mask, rejected_mask
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
class PpoDataset(BaseDataset):
@ -187,7 +235,7 @@ class PpoDataset(BaseDataset):
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
return input_ids, actions, logprobs, rewards
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
class DatasetLoader:
@ -196,12 +244,20 @@ class DatasetLoader:
train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]],
max_len: int,
device: str
device: str,
**kwargs
) -> BaseDataset:
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
"sft": lambda m_len, device: SftDataset(m_len, device=device),
"sft": lambda m_len, device: SftDataset(
m_len,
device=device,
bos_token_id=kwargs.get("bos_token_id"),
eos_token_id=kwargs.get("eos_token_id"),
user_token_id=kwargs.get("user_token_id"),
multi_turn=kwargs.get("multi_turn", False)
),
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
}
dataset = dataset_router[train_type](max_len, device)

View File

@ -32,42 +32,13 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id:
return (token_logprobs * valid_mask).sum(dim=-1)
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
bsz, seq_len = input_ids.size()
is_user_token = input_ids.eq(user_token_id)
turn_id = is_user_token.cumsum(dim=-1)
iq = turn_id.view(bsz, seq_len, 1)
ik = turn_id.view(bsz, 1, seq_len)
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask
return attention_mask
class BaseStrategy(ABC):
def __init__(self, model: nn.Module):
self.model = model
@abstractmethod
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
raise NotImplementedError
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
@ -78,47 +49,45 @@ class SeqStrategy(BaseStrategy):
def __init__(self, model):
super().__init__(model)
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
x, y = batch
B, L = x.size()
logits: Tensor = self.model(x)["logits"]
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
B, L = input_ids.size()
logits: Tensor = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
logits.view(B * L, -1), y.flatten()
input=logits.view(B * L, -1),
target=target_ids.flatten()
)
return loss
class SftStrategy(BaseStrategy):
def __init__(
self,
model: nn.Module,
bos_id: int,
eos_id: int,
user_token_id: int,
multi_turn: bool
):
def __init__(self, model: nn.Module):
super().__init__(model)
self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id)
self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn)
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
x, y, loss_mask = batch
B, L = x.size()
ignore_idx = -1
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
logits: Tensor = self.model(x)["logits"]
masked_y = y.masked_fill(loss_mask == 0, ignore_idx)
ignore_index = -100
B, L = input_ids.size()
logits: Tensor = self.model(
input_ids=input_ids,
input_mask=attn_mask
)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(
logits.view(B * L, -1),
masked_y.flatten(),
ignore_index=ignore_idx
input=logits.view(B * L, -1),
target=target_ids.flatten(),
ignore_index=ignore_index
)
return loss
class DpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, beta):
super().__init__(model)
@ -131,7 +100,8 @@ class DpoStrategy(BaseStrategy):
self.beta = beta
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
good_ids, bad_ids, good_mask, bad_mask = batch
good_ids, bad_ids = batch["chosen"], batch["rejected"]
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)

View File

@ -46,10 +46,10 @@ def test_env():
return self.length
def __getitem__(self, idx):
return (
torch.randint(0, 1000, (64,)),
torch.randint(0, 1000, (64,))
)
return {
"input_ids": torch.randint(0, 1000, (64,)),
"target_ids": torch.randint(0, 1000, (64,))
}
dataset = DummyDataset()