From fb85aaf6a6280baffbf8b51b024f8349f6956df6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 21 Nov 2025 13:37:08 +0800 Subject: [PATCH] =?UTF-8?q?fix(parallel):=20=E4=BF=AE=E6=94=B9=E5=88=97?= =?UTF-8?q?=E5=B9=B6=E8=A1=8C=E7=BA=BF=E6=80=A7=E5=B1=82=E7=BB=93=E6=9E=9C?= =?UTF-8?q?=E8=81=9A=E5=90=88=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/model/parallel.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/khaosz/model/parallel.py b/khaosz/model/parallel.py index bf86388..c888304 100644 --- a/khaosz/model/parallel.py +++ b/khaosz/model/parallel.py @@ -68,14 +68,14 @@ class ColumnParallelLinear(ParallelModel): in_features: int, out_features: int, bias: bool = True, - reduce_results: bool = True + gather_results: bool = True ): super().__init__(process_group) self.in_features = in_features self.out_features = out_features self.out_features_per_rank = out_features // self.world_size - self.reduce_results = reduce_results + self.gather_results = gather_results if out_features % self.world_size != 0: raise ValueError(f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}") @@ -86,8 +86,10 @@ class ColumnParallelLinear(ParallelModel): def forward(self, input: Tensor) -> Tensor: output = F.linear(input, self.weight, self.bias) - if self.reduce_results: - dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) + if self.gather_results: + output_list = [torch.empty_like(output) for _ in range(self.world_size)] + dist.all_gather(output_list, output, group=self.process_group) + output = torch.cat(output_list, dim=-1) return output