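fix(parallel): 修改列并行线性层结果聚合方式 (ColumnParallelLinear now aggregates per-rank outputs with all_gather + torch.cat along the last dimension instead of all_reduce(SUM); the `reduce_results` flag is renamed to `gather_results`)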
This commit is contained in:
parent
6fb6a15e81
commit
fb85aaf6a6
|
|
@ -68,14 +68,14 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
reduce_results: bool = True
|
gather_results: bool = True
|
||||||
):
|
):
|
||||||
super().__init__(process_group)
|
super().__init__(process_group)
|
||||||
|
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.out_features_per_rank = out_features // self.world_size
|
self.out_features_per_rank = out_features // self.world_size
|
||||||
self.reduce_results = reduce_results
|
self.gather_results = gather_results
|
||||||
|
|
||||||
if out_features % self.world_size != 0:
|
if out_features % self.world_size != 0:
|
||||||
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
|
raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
|
||||||
|
|
@ -86,8 +86,10 @@ class ColumnParallelLinear(ParallelModel):
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
output = F.linear(input, self.weight, self.bias)
|
output = F.linear(input, self.weight, self.bias)
|
||||||
|
|
||||||
if self.reduce_results:
|
if self.gather_results:
|
||||||
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
|
output_list = [torch.empty_like(output) for _ in range(self.world_size)]
|
||||||
|
dist.all_gather(output_list, output, group=self.process_group)
|
||||||
|
output = torch.cat(output_list, dim=-1)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue