fix(parallel): 修改列并行线性层结果聚合方式
This commit is contained in:
parent
6fb6a15e81
commit
fb85aaf6a6
|
|
@@ -68,14 +68,14 @@ class ColumnParallelLinear(ParallelModel):
|
|||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
- reduce_results: bool = True
|
||||
+ gather_results: bool = True
|
||||
):
|
||||
super().__init__(process_group)
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.out_features_per_rank = out_features // self.world_size
|
||||
- self.reduce_results = reduce_results
|
||||
+ self.gather_results = gather_results
|
||||
|
||||
if out_features % self.world_size != 0:
|
||||
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
|
||||
|
|
@@ -86,8 +86,10 @@ class ColumnParallelLinear(ParallelModel):
|
|||
def forward(self, input: Tensor) -> Tensor:
|
||||
output = F.linear(input, self.weight, self.bias)
|
||||
|
||||
- if self.reduce_results:
|
||||
- dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
|
||||
+ if self.gather_results:
|
||||
+ output_list = [torch.empty_like(output) for _ in range(self.world_size)]
|
||||
+ dist.all_gather(output_list, output, group=self.process_group)
|
||||
+ output = torch.cat(output_list, dim=-1)
|
||||
|
||||
return output
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue