`groups` parameter for `nn.Linear`

`nn.Conv2d` has a `groups` parameter:

`groups` controls the connections between inputs and outputs. `in_channels` and `out_channels` must both be divisible by `groups`. For example,
- At `groups=1`, all inputs are convolved to all outputs.
- At `groups=2`, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
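To make the `groups=2` case concrete, here is a small sketch that checks a grouped `nn.Conv2d` against two ordinary convolutions applied to each half of the channels and then concatenated. The channel counts, kernel size, and variable names are arbitrary choices for illustration, and bias is disabled only to keep the weight copying short.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Grouped convolution: 4 input channels -> 6 output channels, split into 2 groups.
grouped = nn.Conv2d(4, 6, kernel_size=3, groups=2, bias=False)
# Each filter only sees in_channels // groups = 2 input channels.
print(grouped.weight.shape)  # torch.Size([6, 2, 3, 3])

# Two ordinary convolutions "side by side", each handling half of the channels.
conv_a = nn.Conv2d(2, 3, kernel_size=3, bias=False)
conv_b = nn.Conv2d(2, 3, kernel_size=3, bias=False)
with torch.no_grad():
    # The first out_channels // groups filters belong to group 0, the rest to group 1.
    conv_a.weight.copy_(grouped.weight[:3])
    conv_b.weight.copy_(grouped.weight[3:])

x = torch.randn(1, 4, 8, 8)
side_by_side = torch.cat([conv_a(x[:, :2]), conv_b(x[:, 2:])], dim=1)
print(torch.allclose(grouped(x), side_by_side, atol=1e-5))  # True
```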

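I think something similar would make sense for `nn.Linear`. One advantage is that it reduces the number of weight parameters by a factor of `groups`: each group only maps `in_features // groups` inputs to `out_features // groups` outputs.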
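As a quick sanity check of the parameter savings, the sketch below counts parameters for a dense `nn.Linear(512, 512)` versus four independent blocks of `nn.Linear(128, 128)`. The sizes and `groups=4` are arbitrary illustrative values, and `count` is a hypothetical helper defined here.

```python
from torch import nn

in_features, out_features, groups = 512, 512, 4

dense = nn.Linear(in_features, out_features)
# One independent block per group: (in_features // groups) -> (out_features // groups).
blocks = nn.ModuleList(
    nn.Linear(in_features // groups, out_features // groups) for _ in range(groups)
)


def count(module: nn.Module) -> int:
    # Total number of scalar parameters (weights + biases).
    return sum(p.numel() for p in module.parameters())


print(count(dense))   # 262656 = 512 * 512 weights + 512 biases
print(count(blocks))  # 66048  = 4 * (128 * 128) weights + 4 * 128 biases
```

The weights shrink by exactly `groups` (262144 vs. 65536), while the bias count stays at `out_features`.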

This is what I came up with:

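```python
import torch
from torch import nn


class GroupedLinear(nn.Module):
    """Linear layer whose features are split into `groups` independent blocks,
    analogous to the `groups` argument of nn.Conv2d."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        groups: int = 1,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        if in_features % groups != 0 or out_features % groups != 0:
            raise ValueError("in_features and out_features must both be divisible by groups")

        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups

        # One small nn.Linear per group: (in_features // groups) -> (out_features // groups)
        self._linear_layers = nn.ModuleList(
            [
                nn.Linear(
                    in_features // groups,
                    out_features // groups,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(groups)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the feature (last) dimension into `groups` equal chunks.
        # chunk works for any number of leading dimensions and for
        # non-contiguous tensors, like nn.Linear itself.
        chunks = x.chunk(self.groups, dim=-1)
        result = [layer(chunk) for layer, chunk in zip(self._linear_layers, chunks)]
        # Concatenate the per-group outputs back along the feature dimension.
        return torch.cat(result, dim=-1)
```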
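One way to see what `GroupedLinear` computes: it is the same map as a single dense linear layer whose weight matrix is block-diagonal, where the zero off-diagonal blocks are exactly the parameters that grouping removes. The sketch below (sizes and variable names chosen arbitrarily, and written standalone rather than importing the class above) checks this with `torch.block_diag` and `torch.nn.functional.linear`.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
in_features, out_features, groups = 6, 4, 2

# Per-group layers, mirroring the ModuleList in GroupedLinear.
layers = nn.ModuleList(
    nn.Linear(in_features // groups, out_features // groups) for _ in range(groups)
)

x = torch.randn(5, in_features)
# Grouped computation: split features, apply each small layer, concatenate.
grouped_out = torch.cat(
    [layer(chunk) for layer, chunk in zip(layers, x.chunk(groups, dim=-1))], dim=-1
)

# Same map as one dense layer with a block-diagonal weight.
weight = torch.block_diag(*(layer.weight for layer in layers))  # (out_features, in_features)
bias = torch.cat([layer.bias for layer in layers])
dense_out = F.linear(x, weight, bias)

print(weight.shape)  # torch.Size([4, 6])
print(torch.allclose(grouped_out, dense_out, atol=1e-5))  # True
```

Materializing the block-diagonal weight would store all the zeros, so it saves no memory; the per-group layers keep only the nonzero blocks.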

Is there a better way?