feat(model): add parallel linear layer model support
This commit is contained in:
parent d9ff662e3a
commit 6fb6a15e81
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from torch import Tensor
from typing import Dict

class ParallelModel(nn.Module):
    """Base module that records the process group and this rank's coordinates."""

    def __init__(self, process_group: dist.ProcessGroup):
        super().__init__()
        self.process_group = process_group
        self.rank = dist.get_rank(self.process_group)
        self.world_size = dist.get_world_size(self.process_group)

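Not part of the diff: a minimal setup sketch, assuming a torchrun-style launch, of the process group these modules expect. The "gloo" backend is chosen here so it also runs on CPU; dist.group.WORLD can be swapped for dist.new_group(...) to parallelize over a subset of ranks.

    dist.init_process_group(backend="gloo")
    group = dist.group.WORLD
    base = ParallelModel(group)  # records rank and world_size for the subclasses below
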
class RowParallelLinear(ParallelModel):
    def __init__(
        self,
        process_group: dist.ProcessGroup,
        in_features: int,
        out_features: int,
        bias: bool = True,
        reduce_results: bool = True
    ):
        super().__init__(process_group)

        self.in_features = in_features
        self.out_features = out_features
        self.reduce_results = reduce_results

        if in_features % self.world_size != 0:
            raise ValueError(f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}")

        self.in_features_per_rank = in_features // self.world_size

        # Each rank holds a column slice of the full (out_features, in_features)
        # weight; values are expected to be filled in via load_state_dict.
        self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, input: Tensor) -> Tensor:
        # Partial matmul over this rank's shard of the input features.
        output = F.linear(input, self.weight)

        if self.reduce_results:
            # Summing the partial products across ranks yields the full matmul.
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)

        # The bias is applied once, after the reduction. Note that with
        # reduce_results=False it is still added to each rank's partial sum,
        # so a later reduction by the caller would count it world_size times.
        if self.bias is not None:
            output += self.bias

        return output

    def load_state_dict(self, state_dict: Dict[str, Tensor]):
        # Intentionally replaces nn.Module.load_state_dict so a dense
        # checkpoint can be sharded on the fly.
        full_weight = state_dict.get('weight')
        full_bias = state_dict.get('bias')

        # Take this rank's slice along the in_features dimension.
        start_idx = self.rank * self.in_features_per_rank
        end_idx = start_idx + self.in_features_per_rank
        weight_slice = full_weight[:, start_idx:end_idx]
        self.weight.data.copy_(weight_slice)

        # The bias is not sharded for row parallelism; every rank keeps a copy.
        if self.bias is not None:
            self.bias.data.copy_(full_bias)

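Not part of the diff: a hypothetical round-trip check, assuming the group from the setup sketch above and an untrained dense nn.Linear standing in for a checkpoint. Each rank loads its weight slice and feeds its shard of the input; the all-reduce then reconstructs the dense layer's output.

    full = nn.Linear(1024, 256)
    row = RowParallelLinear(group, in_features=1024, out_features=256)
    row.load_state_dict({'weight': full.weight.data, 'bias': full.bias.data})

    torch.manual_seed(0)       # so every rank draws the same full input
    x = torch.randn(8, 1024)
    lo = row.rank * row.in_features_per_rank
    hi = lo + row.in_features_per_rank
    out = row(x[:, lo:hi])     # equals full(x) up to floating-point error
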
class ColumnParallelLinear(ParallelModel):
    def __init__(
        self,
        process_group: dist.ProcessGroup,
        in_features: int,
        out_features: int,
        bias: bool = True,
        reduce_results: bool = True
    ):
        super().__init__(process_group)

        self.in_features = in_features
        self.out_features = out_features
        self.reduce_results = reduce_results

        if out_features % self.world_size != 0:
            raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")

        self.out_features_per_rank = out_features // self.world_size

        # Each rank holds a row slice of the full (out_features, in_features)
        # weight, plus the matching slice of the bias.
        self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
        self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None

    def forward(self, input: Tensor) -> Tensor:
        # Each rank produces its own slice of the output features, so the
        # sharded bias can be applied directly inside the local matmul.
        output = F.linear(input, self.weight, self.bias)

        if self.reduce_results:
            # Column-parallel shards differ along the feature dimension, so
            # the full result is recovered by gathering; an all-reduce here
            # would wrongly sum unrelated features together.
            gathered = [torch.empty_like(output) for _ in range(self.world_size)]
            dist.all_gather(gathered, output, group=self.process_group)
            output = torch.cat(gathered, dim=-1)

        return output

    def load_state_dict(self, state_dict: Dict[str, Tensor]):
        full_weight = state_dict.get('weight')
        full_bias = state_dict.get('bias')

        # Take this rank's slice along the out_features dimension.
        start_idx = self.rank * self.out_features_per_rank
        end_idx = start_idx + self.out_features_per_rank
        weight_slice = full_weight[start_idx:end_idx, :]
        self.weight.data.copy_(weight_slice)

        # The bias is sharded the same way as the weight rows.
        if self.bias is not None:
            bias_slice = full_bias[start_idx:end_idx]
            self.bias.data.copy_(bias_slice)

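Not part of the diff: a sketch of the usual pairing (Megatron-style tensor parallelism) with hypothetical sizes. Keeping reduce_results=False on the column-parallel layer leaves the hidden activations sharded, so rank k's hidden shard lines up exactly with rank k's input slice in the row-parallel layer, and the whole two-layer block communicates only once, inside RowParallelLinear.

    col = ColumnParallelLinear(group, 1024, 4096, reduce_results=False)
    row = RowParallelLinear(group, 4096, 1024, reduce_results=True)

    x = torch.randn(8, 1024)   # replicated input on every rank
    h = F.relu(col(x))         # this rank's shard of the hidden features
    y = row(h)                 # full output, identical on all ranks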