diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py index 02c34cb5..38fda05 100644 --- a/khaosz/trainer/data_util.py +++ b/khaosz/trainer/data_util.py @@ -110,12 +110,11 @@ class MutiSegmentFetcher: class BaseDataset(Dataset, ABC): - def __init__(self, chunk_size: int, device: str): + def __init__(self, chunk_size: int): super().__init__() self.segments: MutiSeg = {} self.chunk_size = chunk_size self.total_samples = 0 - self.device = device def save(self, save_path: str): keys = list(self.segments.keys()) @@ -148,9 +147,8 @@ class SeqDataset(BaseDataset): def __init__( self, chunk_size, - device='cuda' ): - super().__init__(chunk_size, device) + super().__init__(chunk_size) self.fetcher = MutiSegmentFetcher(self.segments) def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: @@ -160,8 +158,8 @@ class SeqDataset(BaseDataset): begin_idx = index * self.chunk_size end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1) - x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long) - y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long) + x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) + y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) return {"input_ids": x, "target_ids": y} @@ -175,9 +173,8 @@ class SftDataset(BaseDataset): eos_token_id, user_token_id, multi_turn=False, - device='cuda' ): - super().__init__(chunk_size, device) + super().__init__(chunk_size) self.fetcher = MutiSegmentFetcher(self.segments) self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id @@ -191,8 +188,8 @@ class SftDataset(BaseDataset): begin_idx = index * self.chunk_size end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1) - x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long) - y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long) + x = 
self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) + y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) # fix the eos_token_id bug(change target_ids to input_ids) loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id) @@ -202,8 +199,8 @@ class SftDataset(BaseDataset): class DpoDataset(BaseDataset): - def __init__(self, chunk_size: int, device="cuda"): - super().__init__(chunk_size, device) + def __init__(self, chunk_size: int): + super().__init__(chunk_size) self.fetcher = MutiSegmentFetcher(self.segments) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: @@ -213,17 +210,17 @@ class DpoDataset(BaseDataset): start_idx = index * self.chunk_size end_idx = min(start_idx + self.chunk_size, self.total_samples - 1) - chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long) - rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long) - chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool) - rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool) + chosen = self._fetch_data(start_idx, end_idx, "chosen").to(dtype=torch.long) + rejected = self._fetch_data(start_idx, end_idx, "rejected").to(dtype=torch.long) + chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(dtype=torch.bool) + rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(dtype=torch.bool) return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask} class PpoDataset(BaseDataset): - def __init__(self, chunk_size: int, device="cuda"): - super().__init__(chunk_size, device) + def __init__(self, chunk_size: int): + super().__init__(chunk_size) self.fetcher = MutiSegmentFetcher(self.segments) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: @@ 
-233,12 +230,11 @@ class PpoDataset(BaseDataset): begin_idx = index * self.chunk_size end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1) - - input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device), - actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device), - logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device), - rewards = self._fetch_data(begin_idx, end_idx, "rewards") + input_ids = self._fetch_data(begin_idx, end_idx, "input_ids") + actions = self._fetch_data(begin_idx, end_idx, "actions") + logprobs = self._fetch_data(begin_idx, end_idx, "logprobs") + rewards = self._fetch_data(begin_idx, end_idx, "rewards") return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards} @@ -249,23 +245,21 @@ class DatasetLoader: train_type: Literal["seq", "sft", "dpo"], load_path: Union[str, List[str]], max_len: int, - device: str, **kwargs ) -> BaseDataset: - dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = { - "seq": lambda m_len, device: SeqDataset(m_len, device=device), - "sft": lambda m_len, device: SftDataset( - m_len, - device=device, + dataset_router: Dict[str, Callable[[int], BaseDataset]] = { + "seq": lambda max_len: SeqDataset(max_len), + "sft": lambda max_len: SftDataset( + max_len, bos_token_id=kwargs.get("bos_token_id"), eos_token_id=kwargs.get("eos_token_id"), user_token_id=kwargs.get("user_token_id"), multi_turn=kwargs.get("multi_turn") ), - "dpo": lambda m_len, device: DpoDataset(m_len, device=device), + "dpo": lambda max_len: DpoDataset(max_len), } - dataset = dataset_router[train_type](max_len, device) + dataset = dataset_router[train_type](max_len) dataset.load(load_path) return dataset diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index c14f492..203bf43 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -32,10 +32,14 @@ def get_logprobs(model:nn.Module,
input_ids: Tensor, mask: Tensor, pad_token_id: return (token_logprobs * valid_mask).sum(dim=-1) +def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]: + return {key: value.to(device, non_blocking=True) for key, value in batch.items()} + class BaseStrategy(ABC): - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, device: str): self.model = model + self.device = device @abstractmethod def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: @@ -46,10 +50,11 @@ class BaseStrategy(ABC): class SeqStrategy(BaseStrategy): - def __init__(self, model): - super().__init__(model) + def __init__(self, model, device): + super().__init__(model, device) def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: + batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] B, L = input_ids.size() logits: Tensor = self.model(input_ids=input_ids)["logits"] @@ -62,10 +67,11 @@ class SeqStrategy(BaseStrategy): class SftStrategy(BaseStrategy): - def __init__(self, model: nn.Module): - super().__init__(model) + def __init__(self, model, device): + super().__init__(model, device) def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: + batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"] @@ -89,8 +95,8 @@ class SftStrategy(BaseStrategy): class DpoStrategy(BaseStrategy): - def __init__(self, model, pad_token_id, beta): - super().__init__(model) + def __init__(self, model, device, pad_token_id, beta): + super().__init__(model, device) ref_model = copy.deepcopy(self.model) ref_model.requires_grad_(False) ref_model.eval() @@ -100,6 +106,7 @@ class DpoStrategy(BaseStrategy): self.beta = beta def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: + batch = move_to_device(batch, self.device) good_ids, bad_ids = batch["chosen"], batch["rejected"] good_mask, bad_mask =
batch["chosen_mask"], batch["rejected_mask"] @@ -156,12 +163,13 @@ class PpoStrategy(BaseStrategy): class StrategyFactory: - def load(model, train_type, **kwargs): + def load(model, train_type, device, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { - "seq": lambda: SeqStrategy(model), - "sft": lambda: SftStrategy(model), + "seq": lambda: SeqStrategy(model, device), + "sft": lambda: SftStrategy(model, device), "dpo": lambda: DpoStrategy( model, + device, kwargs.get("pad_token_id"), kwargs.get("dpo_beta") ) diff --git a/train.py b/train.py index feb7f8b..c87d514 100644 --- a/train.py +++ b/train.py @@ -59,14 +59,14 @@ def train( strategy = StrategyFactory.load( model, train_type, + device, **kwargs ) dataset = DatasetLoader.load( train_type=train_type, load_path=cache_files, - max_len=parameter.config.m_len, - device=device, + max_len=parameter.config.m_len, **kwargs )