nn.Conv2d has a parameter groups that controls the connections between inputs and outputs: in_channels and out_channels must both be divisible by groups. For example,
- At groups=1, all inputs are convolved to all outputs.
- At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, with the results concatenated (see the sketch below).
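
To make that concrete, here is a quick sketch of what groups does for Conv2d (the channel counts are just for illustration):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)

# groups=1: every output channel is connected to all 8 input channels
full = nn.Conv2d(8, 16, kernel_size=3, groups=1)

# groups=2: two independent convolutions, each over 4 input channels
# and producing 8 output channels, concatenated along the channel dim
grouped = nn.Conv2d(8, 16, kernel_size=3, groups=2)

print(full.weight.shape)     # torch.Size([16, 8, 3, 3])
print(grouped.weight.shape)  # torch.Size([16, 4, 3, 3]) -> half the weights
print(full(x).shape, grouped(x).shape)  # both torch.Size([1, 16, 30, 30])
```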
I think something similar makes sense for nn.Linear. For example, it would reduce the number of weights roughly by a factor of groups.
This is what I came up with:
```python
import torch
import torch.nn as nn


class GroupedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        groups: int = 1,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        assert in_features % groups == 0
        assert out_features % groups == 0
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        # One independent Linear per group, each mapping a slice of the
        # input features to a slice of the output features.
        self._linear_layers = nn.ModuleList(
            [
                nn.Linear(
                    in_features // groups,
                    out_features // groups,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(groups)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, in_features) -> (batch, groups, in_features // groups)
        x = x.view(x.shape[0], self.groups, -1)
        result = [l(x[:, i]) for i, l in enumerate(self._linear_layers)]
        # Concatenate the per-group outputs back to (batch, out_features)
        return torch.cat(result, dim=1)
```
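
A quick sanity check of the parameter saving and the output shape, assuming the class above (the sizes are arbitrary):

```python
x = torch.randn(32, 1024)

plain = nn.Linear(1024, 1024)
grouped = GroupedLinear(1024, 1024, groups=4)

print(sum(p.numel() for p in plain.parameters()))    # 1049600 (weights + bias)
print(sum(p.numel() for p in grouped.parameters()))  # 263168 -> roughly 4x fewer
print(grouped(x).shape)                              # torch.Size([32, 1024])
```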
Is there a better way?