refactor(khaosz/trainer): 修改设备参数传递发生阶段

This commit is contained in:
ViperEkura 2025-10-04 12:12:21 +08:00
parent 240ee00221
commit 465a1a9373
3 changed files with 45 additions and 43 deletions

View File

@ -110,12 +110,11 @@ class MutiSegmentFetcher:
class BaseDataset(Dataset, ABC):
def __init__(self, chunk_size: int, device: str):
def __init__(self, chunk_size: int):
super().__init__()
self.segments: MutiSeg = {}
self.chunk_size = chunk_size
self.total_samples = 0
self.device = device
def save(self, save_path: str):
keys = list(self.segments.keys())
@ -148,9 +147,8 @@ class SeqDataset(BaseDataset):
def __init__(
self,
chunk_size,
device='cuda'
):
super().__init__(chunk_size, device)
super().__init__(chunk_size)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
@ -160,8 +158,8 @@ class SeqDataset(BaseDataset):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
return {"input_ids": x, "target_ids": y}
@ -175,9 +173,8 @@ class SftDataset(BaseDataset):
eos_token_id,
user_token_id,
multi_turn=False,
device='cuda'
):
super().__init__(chunk_size, device)
super().__init__(chunk_size)
self.fetcher = MutiSegmentFetcher(self.segments)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
@ -191,8 +188,8 @@ class SftDataset(BaseDataset):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
# fix the eos_token_id bug(change target_ids to input_ids)
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
@ -202,8 +199,8 @@ class SftDataset(BaseDataset):
class DpoDataset(BaseDataset):
def __init__(self, chunk_size: int, device="cuda"):
super().__init__(chunk_size, device)
def __init__(self, chunk_size: int):
super().__init__(chunk_size)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -213,17 +210,17 @@ class DpoDataset(BaseDataset):
start_idx = index * self.chunk_size
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(dtype=torch.long)
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
class PpoDataset(BaseDataset):
def __init__(self, chunk_size: int, device="cuda"):
super().__init__(chunk_size, device)
def __init__(self, chunk_size: int):
super().__init__(chunk_size)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
@ -234,11 +231,10 @@ class PpoDataset(BaseDataset):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device),
actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device),
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
actions = self._fetch_data(begin_idx, end_idx, "actions"),
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs"),
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
@ -249,23 +245,21 @@ class DatasetLoader:
train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]],
max_len: int,
device: str,
**kwargs
) -> BaseDataset:
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
"sft": lambda m_len, device: SftDataset(
m_len,
device=device,
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
"seq": lambda max_len: SeqDataset(max_len),
"sft": lambda max_len: SftDataset(
max_len,
bos_token_id=kwargs.get("bos_token_id"),
eos_token_id=kwargs.get("eos_token_id"),
user_token_id=kwargs.get("user_token_id"),
multi_turn=kwargs.get("multi_turn")
),
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
"dpo": lambda max_len: DpoDataset(max_len),
}
dataset = dataset_router[train_type](max_len, device)
dataset = dataset_router[train_type](max_len)
dataset.load(load_path)
return dataset

View File

@ -32,10 +32,14 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id:
return (token_logprobs * valid_mask).sum(dim=-1)
def move_to_device(batch:Dict[str, Tensor], device: str) -> Any:
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
class BaseStrategy(ABC):
def __init__(self, model: nn.Module):
def __init__(self, model: nn.Module, device: str):
self.model = model
self.device = device
@abstractmethod
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
@ -46,10 +50,11 @@ class BaseStrategy(ABC):
class SeqStrategy(BaseStrategy):
def __init__(self, model):
super().__init__(model)
def __init__(self, model, device):
super().__init__(model, device)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
B, L = input_ids.size()
logits: Tensor = self.model(input_ids=input_ids)["logits"]
@ -62,10 +67,11 @@ class SeqStrategy(BaseStrategy):
class SftStrategy(BaseStrategy):
def __init__(self, model: nn.Module):
super().__init__(model)
def __init__(self, model, device):
super().__init__(model, device)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
@ -89,8 +95,8 @@ class SftStrategy(BaseStrategy):
class DpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, beta):
super().__init__(model)
def __init__(self, model, device, pad_token_id, beta):
super().__init__(model, device)
ref_model = copy.deepcopy(self.model)
ref_model.requires_grad_(False)
ref_model.eval()
@ -100,6 +106,7 @@ class DpoStrategy(BaseStrategy):
self.beta = beta
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
batch = move_to_device(batch, self.device)
good_ids, bad_ids = batch["chosen"], batch["rejected"]
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
@ -156,12 +163,13 @@ class PpoStrategy(BaseStrategy):
class StrategyFactory:
def load(model, train_type, **kwargs):
def load(model, train_type, device, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(model),
"seq": lambda: SeqStrategy(model, device),
"sft": lambda: SftStrategy(model, device),
"dpo": lambda: DpoStrategy(
model,
device,
kwargs.get("pad_token_id"),
kwargs.get("dpo_beta")
)

View File

@ -59,14 +59,14 @@ def train(
strategy = StrategyFactory.load(
model,
train_type,
device,
**kwargs
)
dataset = DatasetLoader.load(
train_type=train_type,
load_path=cache_files,
max_len=parameter.config.m_len,
device=device,
max_len=parameter.config.m_len
**kwargs
)