fix(parallel): 修改列并行线性层结果聚合方式

This commit is contained in:
ViperEkura 2025-11-21 13:37:08 +08:00
parent 6fb6a15e81
commit fb85aaf6a6
1 changed files with 6 additions and 4 deletions

View File

@@ -68,14 +68,14 @@ class ColumnParallelLinear(ParallelModel):
         in_features: int,
         out_features: int,
         bias: bool = True,
-        reduce_results: bool = True
+        gather_results: bool = True
     ):
         super().__init__(process_group)
         self.in_features = in_features
         self.out_features = out_features
         self.out_features_per_rank = out_features // self.world_size
-        self.reduce_results = reduce_results
+        self.gather_results = gather_results
         if out_features % self.world_size != 0:
             raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}")
@@ -86,8 +86,10 @@ class ColumnParallelLinear(ParallelModel):
     def forward(self, input: Tensor) -> Tensor:
         output = F.linear(input, self.weight, self.bias)
-        if self.reduce_results:
-            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+        if self.gather_results:
+            output_list = [torch.empty_like(output) for _ in range(self.world_size)]
+            dist.all_gather(output_list, output, group=self.process_group)
+            output = torch.cat(output_list, dim=-1)
         return output