refactor(khaosz/trainer): 修改设备参数传递发生阶段
This commit is contained in:
parent
240ee00221
commit
465a1a9373
|
|
@ -110,12 +110,11 @@ class MutiSegmentFetcher:
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
class BaseDataset(Dataset, ABC):
|
||||||
def __init__(self, chunk_size: int, device: str):
|
def __init__(self, chunk_size: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.segments: MutiSeg = {}
|
self.segments: MutiSeg = {}
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
self.total_samples = 0
|
self.total_samples = 0
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def save(self, save_path: str):
|
def save(self, save_path: str):
|
||||||
keys = list(self.segments.keys())
|
keys = list(self.segments.keys())
|
||||||
|
|
@ -148,9 +147,8 @@ class SeqDataset(BaseDataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
chunk_size,
|
chunk_size,
|
||||||
device='cuda'
|
|
||||||
):
|
):
|
||||||
super().__init__(chunk_size, device)
|
super().__init__(chunk_size)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
|
|
@ -160,8 +158,8 @@ class SeqDataset(BaseDataset):
|
||||||
begin_idx = index * self.chunk_size
|
begin_idx = index * self.chunk_size
|
||||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y}
|
return {"input_ids": x, "target_ids": y}
|
||||||
|
|
||||||
|
|
@ -175,9 +173,8 @@ class SftDataset(BaseDataset):
|
||||||
eos_token_id,
|
eos_token_id,
|
||||||
user_token_id,
|
user_token_id,
|
||||||
multi_turn=False,
|
multi_turn=False,
|
||||||
device='cuda'
|
|
||||||
):
|
):
|
||||||
super().__init__(chunk_size, device)
|
super().__init__(chunk_size)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
@ -191,8 +188,8 @@ class SftDataset(BaseDataset):
|
||||||
begin_idx = index * self.chunk_size
|
begin_idx = index * self.chunk_size
|
||||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
||||||
|
|
||||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||||
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
||||||
|
|
@ -202,8 +199,8 @@ class SftDataset(BaseDataset):
|
||||||
|
|
||||||
|
|
||||||
class DpoDataset(BaseDataset):
|
class DpoDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size: int, device="cuda"):
|
def __init__(self, chunk_size: int):
|
||||||
super().__init__(chunk_size, device)
|
super().__init__(chunk_size)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
|
@ -213,17 +210,17 @@ class DpoDataset(BaseDataset):
|
||||||
start_idx = index * self.chunk_size
|
start_idx = index * self.chunk_size
|
||||||
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
|
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
|
||||||
|
|
||||||
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
|
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||||
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
|
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||||
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
|
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
|
||||||
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
|
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
|
||||||
|
|
||||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||||
|
|
||||||
|
|
||||||
class PpoDataset(BaseDataset):
|
class PpoDataset(BaseDataset):
|
||||||
def __init__(self, chunk_size: int, device="cuda"):
|
def __init__(self, chunk_size: int):
|
||||||
super().__init__(chunk_size, device)
|
super().__init__(chunk_size)
|
||||||
self.fetcher = MutiSegmentFetcher(self.segments)
|
self.fetcher = MutiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
|
|
@ -233,12 +230,11 @@ class PpoDataset(BaseDataset):
|
||||||
|
|
||||||
begin_idx = index * self.chunk_size
|
begin_idx = index * self.chunk_size
|
||||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||||
|
|
||||||
|
|
||||||
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device),
|
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids")
|
||||||
actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device),
|
actions = self._fetch_data(begin_idx, end_idx, "actions")
|
||||||
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
|
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs")
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
|
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
|
||||||
|
|
||||||
|
|
@ -249,23 +245,21 @@ class DatasetLoader:
|
||||||
train_type: Literal["seq", "sft", "dpo"],
|
train_type: Literal["seq", "sft", "dpo"],
|
||||||
load_path: Union[str, List[str]],
|
load_path: Union[str, List[str]],
|
||||||
max_len: int,
|
max_len: int,
|
||||||
device: str,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> BaseDataset:
|
) -> BaseDataset:
|
||||||
|
|
||||||
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
|
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
||||||
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
|
"seq": lambda max_len: SeqDataset(max_len),
|
||||||
"sft": lambda m_len, device: SftDataset(
|
"sft": lambda max_len: SftDataset(
|
||||||
m_len,
|
max_len,
|
||||||
device=device,
|
|
||||||
bos_token_id=kwargs.get("bos_token_id"),
|
bos_token_id=kwargs.get("bos_token_id"),
|
||||||
eos_token_id=kwargs.get("eos_token_id"),
|
eos_token_id=kwargs.get("eos_token_id"),
|
||||||
user_token_id=kwargs.get("user_token_id"),
|
user_token_id=kwargs.get("user_token_id"),
|
||||||
multi_turn=kwargs.get("multi_turn")
|
multi_turn=kwargs.get("multi_turn")
|
||||||
),
|
),
|
||||||
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
|
"dpo": lambda max_len: DpoDataset(max_len),
|
||||||
}
|
}
|
||||||
dataset = dataset_router[train_type](max_len, device)
|
dataset = dataset_router[train_type](max_len)
|
||||||
dataset.load(load_path)
|
dataset.load(load_path)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
|
||||||
|
|
@ -32,10 +32,14 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id:
|
||||||
|
|
||||||
return (token_logprobs * valid_mask).sum(dim=-1)
|
return (token_logprobs * valid_mask).sum(dim=-1)
|
||||||
|
|
||||||
|
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
|
||||||
|
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
def __init__(self, model: nn.Module):
|
def __init__(self, model: nn.Module, device: str):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.device = device
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
|
@ -46,10 +50,11 @@ class BaseStrategy(ABC):
|
||||||
|
|
||||||
|
|
||||||
class SeqStrategy(BaseStrategy):
|
class SeqStrategy(BaseStrategy):
|
||||||
def __init__(self, model):
|
def __init__(self, model, device):
|
||||||
super().__init__(model)
|
super().__init__(model, device)
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||||
B, L = input_ids.size()
|
B, L = input_ids.size()
|
||||||
logits: Tensor = self.model(input_ids=input_ids)["logits"]
|
logits: Tensor = self.model(input_ids=input_ids)["logits"]
|
||||||
|
|
@ -62,10 +67,11 @@ class SeqStrategy(BaseStrategy):
|
||||||
|
|
||||||
|
|
||||||
class SftStrategy(BaseStrategy):
|
class SftStrategy(BaseStrategy):
|
||||||
def __init__(self, model: nn.Module):
|
def __init__(self, model, device):
|
||||||
super().__init__(model)
|
super().__init__(model, device)
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
|
||||||
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
|
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
|
||||||
|
|
||||||
|
|
@ -89,8 +95,8 @@ class SftStrategy(BaseStrategy):
|
||||||
|
|
||||||
|
|
||||||
class DpoStrategy(BaseStrategy):
|
class DpoStrategy(BaseStrategy):
|
||||||
def __init__(self, model, pad_token_id, beta):
|
def __init__(self, model, device, pad_token_id, beta):
|
||||||
super().__init__(model)
|
super().__init__(model, device)
|
||||||
ref_model = copy.deepcopy(self.model)
|
ref_model = copy.deepcopy(self.model)
|
||||||
ref_model.requires_grad_(False)
|
ref_model.requires_grad_(False)
|
||||||
ref_model.eval()
|
ref_model.eval()
|
||||||
|
|
@ -100,6 +106,7 @@ class DpoStrategy(BaseStrategy):
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
|
|
||||||
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
batch = move_to_device(batch, self.device)
|
||||||
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
good_ids, bad_ids = batch["chosen"], batch["rejected"]
|
||||||
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
|
||||||
|
|
||||||
|
|
@ -156,12 +163,13 @@ class PpoStrategy(BaseStrategy):
|
||||||
|
|
||||||
class StrategyFactory:
|
class StrategyFactory:
|
||||||
|
|
||||||
def load(model, train_type, **kwargs):
|
def load(model, train_type, device, **kwargs):
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||||
"seq": lambda: SeqStrategy(model),
|
"seq": lambda: SeqStrategy(model, device),
|
||||||
"sft": lambda: SftStrategy(model),
|
"sft": lambda: SftStrategy(model, device),
|
||||||
"dpo": lambda: DpoStrategy(
|
"dpo": lambda: DpoStrategy(
|
||||||
model,
|
model,
|
||||||
|
device,
|
||||||
kwargs.get("pad_token_id"),
|
kwargs.get("pad_token_id"),
|
||||||
kwargs.get("dpo_beta")
|
kwargs.get("dpo_beta")
|
||||||
)
|
)
|
||||||
|
|
|
||||||
4
train.py
4
train.py
|
|
@ -59,14 +59,14 @@ def train(
|
||||||
strategy = StrategyFactory.load(
|
strategy = StrategyFactory.load(
|
||||||
model,
|
model,
|
||||||
train_type,
|
train_type,
|
||||||
|
device,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=cache_files,
|
load_path=cache_files,
|
||||||
max_len=parameter.config.m_len,
|
max_len=parameter.config.m_len,
|
||||||
device=device,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue