nn.Conv2d has a parameter groups that controls the connections between inputs and outputs: in_channels and out_channels must both be divisible by groups. For example,
- At groups=1, all inputs are convolved to all outputs.
- At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, with the results concatenated (see the sketch below).
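
To make that concrete, here is a quick sketch of what groups does for Conv2d (the channel counts are just for illustration):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)

# groups=1: every output channel is connected to all 8 input channels
full = nn.Conv2d(8, 16, kernel_size=3, groups=1)

# groups=2: two independent convolutions, each over 4 input channels
# and producing 8 output channels, concatenated along the channel dim
grouped = nn.Conv2d(8, 16, kernel_size=3, groups=2)

print(full.weight.shape)     # torch.Size([16, 8, 3, 3])
print(grouped.weight.shape)  # torch.Size([16, 4, 3, 3]) -> half the weights
print(full(x).shape, grouped(x).shape)  # both torch.Size([1, 16, 30, 30])
```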
I think something similar makes sense for nn.Linear. For example, it would reduce the number of weights roughly by a factor of groups.
This is what I came up with:
```python
import torch
import torch.nn as nn


class GroupedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        groups: int = 1,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        assert in_features % groups == 0
        assert out_features % groups == 0
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        # One independent Linear per group, each mapping a slice of the
        # input features to a slice of the output features.
        self._linear_layers = nn.ModuleList(
            [
                nn.Linear(
                    in_features // groups,
                    out_features // groups,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(groups)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, in_features) -> (batch, groups, in_features // groups)
        x = x.view(x.shape[0], self.groups, -1)
        result = [l(x[:, i]) for i, l in enumerate(self._linear_layers)]
        # Concatenate the per-group outputs back to (batch, out_features)
        return torch.cat(result, dim=1)
```
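
A quick sanity check of the parameter saving and the output shape, assuming the class above (the sizes are arbitrary):

```python
x = torch.randn(32, 1024)

plain = nn.Linear(1024, 1024)
grouped = GroupedLinear(1024, 1024, groups=4)

print(sum(p.numel() for p in plain.parameters()))    # 1049600 (weights + bias)
print(sum(p.numel() for p in grouped.parameters()))  # 263168 -> roughly 4x fewer
print(grouped(x).shape)                              # torch.Size([32, 1024])
```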
Is there a better way?