I think you can use the groups parameter of Conv2d.
groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups.
At groups=1, all inputs are convolved to all outputs.
At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, with the two outputs then concatenated.
At groups=in_channels, each input channel is convolved with its own set of filters (of size out_channels // in_channels).
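To make the groups=2 and groups=in_channels cases concrete, here is a small sketch that checks the "two conv layers side by side" description numerically. The channel counts, kernel size, and input shape are arbitrary values chosen for illustration, not anything from your setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions); both must be divisible by groups.
in_channels, out_channels, groups = 4, 6, 2

grouped = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)

# Each filter only sees in_channels // groups input channels.
weight_shape = tuple(grouped.weight.shape)  # (out_channels, in_channels // groups, 3, 3)

x = torch.randn(1, in_channels, 8, 8)

# Build one ordinary conv per group and copy in the matching slice of weights.
in_per_group = in_channels // groups
out_per_group = out_channels // groups
branches = []
for g in range(groups):
    conv = nn.Conv2d(in_per_group, out_per_group, kernel_size=3, padding=1)
    out_slice = slice(g * out_per_group, (g + 1) * out_per_group)
    with torch.no_grad():
        conv.weight.copy_(grouped.weight[out_slice])
        conv.bias.copy_(grouped.bias[out_slice])
    branches.append(conv)

with torch.no_grad():
    y_grouped = grouped(x)
    # Feed each branch its own slice of input channels, then concatenate on the channel dim.
    y_manual = torch.cat(
        [
            conv(x[:, g * in_per_group:(g + 1) * in_per_group])
            for g, conv in enumerate(branches)
        ],
        dim=1,
    )

same = torch.allclose(y_grouped, y_manual, atol=1e-5)

# groups == in_channels: every input channel gets its own filters
# (here out_channels // in_channels = 2 filters per input channel).
depthwise = nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, padding=1, groups=in_channels)
depthwise_weight_shape = tuple(depthwise.weight.shape)

print(weight_shape, same, depthwise_weight_shape)
```

The weight shape is the quickest way to see what groups does: the second dimension is in_channels // groups, so it shrinks to 1 when groups equals in_channels.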