diff --git a/khaosz/model/parallel.py b/khaosz/model/parallel.py index bf86388..c888304 100644 --- a/khaosz/model/parallel.py +++ b/khaosz/model/parallel.py @@ -68,14 +68,14 @@ class ColumnParallelLinear(ParallelModel): in_features: int, out_features: int, bias: bool = True, - reduce_results: bool = True + gather_results: bool = True ): super().__init__(process_group) self.in_features = in_features self.out_features = out_features self.out_features_per_rank = out_features // self.world_size - self.reduce_results = reduce_results + self.gather_results = gather_results if out_features % self.world_size != 0: raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}") @@ -86,8 +86,10 @@ class ColumnParallelLinear(ParallelModel): def forward(self, input: Tensor) -> Tensor: output = F.linear(input, self.weight, self.bias) - if self.reduce_results: - dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) + if self.gather_results: + output_list = [torch.empty_like(output) for _ in range(self.world_size)] + dist.all_gather(output_list, output, group=self.process_group) + output = torch.cat(output_list, dim=-1) return output