diff --git a/khaosz/model/parallel.py b/khaosz/model/parallel.py
new file mode 100644
index 0000000..bf86388
--- /dev/null
+++ b/khaosz/model/parallel.py
@@ -0,0 +1,134 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from torch import Tensor
+from typing import Dict
+
+
+class ParallelModel(nn.Module):
+    """Base class that records the process group and per-rank metadata."""
+
+    def __init__(self, process_group: dist.ProcessGroup):
+        super().__init__()
+        self.process_group = process_group
+        self.rank = dist.get_rank(self.process_group)
+        self.world_size = dist.get_world_size(self.process_group)
+
+
+class RowParallelLinear(ParallelModel):
+    """Linear layer whose weight is sharded along in_features.
+
+    Each rank holds an [out_features, in_features // world_size] shard and
+    computes a partial matmul; the partial results are summed across ranks
+    with an all-reduce.
+    """
+
+    def __init__(
+        self,
+        process_group: dist.ProcessGroup,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        reduce_results: bool = True
+    ):
+        super().__init__(process_group)
+
+        if in_features % self.world_size != 0:
+            raise ValueError(
+                f"in_features must be divisible by world_size. "
+                f"Got {in_features} and {self.world_size}"
+            )
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.in_features_per_rank = in_features // self.world_size
+        self.reduce_results = reduce_results
+
+        self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
+        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
+
+    def forward(self, input: Tensor) -> Tensor:
+        # Partial matmul on this rank's slice of the input features.
+        output = F.linear(input, self.weight)
+
+        if self.reduce_results:
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+
+        # Apply the bias after the reduction so it is added once, not world_size times.
+        if self.bias is not None:
+            output += self.bias
+
+        return output
+
+    def load_state_dict(self, state_dict: Dict[str, Tensor]):
+        # Load an *unsharded* checkpoint: each rank copies its column slice of
+        # the full weight. Intentionally shadows nn.Module.load_state_dict.
+        full_weight = state_dict['weight']
+
+        start_idx = self.rank * self.in_features_per_rank
+        end_idx = start_idx + self.in_features_per_rank
+        self.weight.data.copy_(full_weight[:, start_idx:end_idx])
+
+        if self.bias is not None:
+            # The bias is not partitioned for row parallelism.
+            self.bias.data.copy_(state_dict['bias'])
+
+
+class ColumnParallelLinear(ParallelModel):
+    """Linear layer whose weight is sharded along out_features.
+
+    Each rank holds an [out_features // world_size, in_features] shard and
+    computes its own slice of the output, which can optionally be gathered
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        process_group: dist.ProcessGroup,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        gather_output: bool = True
+    ):
+        super().__init__(process_group)
+
+        if out_features % self.world_size != 0:
+            raise ValueError(
+                f"out_features must be divisible by world_size. "
+                f"Got {out_features} and {self.world_size}"
+            )
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.out_features_per_rank = out_features // self.world_size
+        self.gather_output = gather_output
+
+        self.weight = nn.Parameter(torch.empty(self.out_features_per_rank, self.in_features))
+        self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
+
+    def forward(self, input: Tensor) -> Tensor:
+        # Each rank computes its own [..., out_features_per_rank] output slice.
+        output = F.linear(input, self.weight, self.bias)
+
+        if self.gather_output:
+            # The shards hold different output features, so they are concatenated,
+            # not summed (an all-reduce here would be incorrect).
+            gathered = [torch.empty_like(output) for _ in range(self.world_size)]
+            dist.all_gather(gathered, output, group=self.process_group)
+            output = torch.cat(gathered, dim=-1)
+
+        return output
+
+    def load_state_dict(self, state_dict: Dict[str, Tensor]):
+        # Load an *unsharded* checkpoint: each rank copies its row slice of the
+        # full weight and the matching slice of the bias.
+        full_weight = state_dict['weight']
+
+        start_idx = self.rank * self.out_features_per_rank
+        end_idx = start_idx + self.out_features_per_rank
+        self.weight.data.copy_(full_weight[start_idx:end_idx, :])
+
+        if self.bias is not None:
+            self.bias.data.copy_(state_dict['bias'][start_idx:end_idx])
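
For review context, a minimal usage sketch of how the two layers compose, in the Megatron style: a column-parallel projection kept sharded (gather_output=False) feeding a row-parallel projection that all-reduces the partial sums, so only one collective is needed per MLP block. Everything below is illustrative and not part of this diff: TensorParallelMLP, the dimensions, the gloo backend, and the torchrun launch line are assumptions, not khaosz APIs. Note also that the plain torch.distributed collectives used in parallel.py are not autograd-aware, so the sketch runs under torch.no_grad().

    # Launch with e.g.: torchrun --nproc_per_node=2 mlp_sketch.py  (file name hypothetical)
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.distributed as dist

    from khaosz.model.parallel import ColumnParallelLinear, RowParallelLinear


    class TensorParallelMLP(nn.Module):
        """Illustrative composition; not a khaosz API."""

        def __init__(self, group: dist.ProcessGroup, d_model: int, d_ff: int):
            super().__init__()
            # Column-parallel first: each rank keeps its activation shard local,
            # so no collective is needed between the two layers.
            self.up = ColumnParallelLinear(group, d_model, d_ff, gather_output=False)
            # Row-parallel second: consumes the sharded activations and all-reduces
            # the partial sums back into a full [..., d_model] output.
            self.down = RowParallelLinear(group, d_ff, d_model, reduce_results=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.down(F.silu(self.up(x)))


    def main():
        dist.init_process_group(backend="gloo")  # "nccl" on GPUs
        group = dist.group.WORLD

        mlp = TensorParallelMLP(group, d_model=64, d_ff=256)

        # Build the same full (unsharded) checkpoint on every rank, then let the
        # custom load_state_dict methods slice out each rank's shard.
        torch.manual_seed(0)
        mlp.up.load_state_dict({"weight": torch.randn(256, 64), "bias": torch.zeros(256)})
        mlp.down.load_state_dict({"weight": torch.randn(64, 256), "bias": torch.zeros(64)})

        with torch.no_grad():  # plain dist collectives are not autograd-aware
            y = mlp(torch.randn(4, 64))
        print(dist.get_rank(group), y.shape)  # torch.Size([4, 64]) on every rank

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

The column-then-row ordering is the standard choice because the intermediate activations stay sharded across ranks; gathering them between the layers would add a second collective per block for no benefit.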