nn.Conv3d: mapping each input channel to the corresponding output channel without summation with other input channels

Hi,

I’m doing 3D convolutions on multi-channel MR images. I have an equal number of input and output channels, and I want each input channel to be mapped to the corresponding output channel, without being summed with the other input channels. I’m planning to use “groups” and I want to make sure it’s the correct strategy.

This block of code sums the input channels, which I don’t want:

import torch.nn as nn

class separateChannelsConv(nn.Module):

    def __init__(self, channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Standard convolution: every output channel sums over all input channels.
        self.conv = nn.Conv3d(channels, channels, kernel_size, stride, padding)

    def forward(self, x):
        return self.conv(x)
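You can see the cross-channel summation directly in the weight shape of a plain Conv3d: it is (out_channels, in_channels, kD, kH, kW), so every output channel has a kernel slice for every input channel. A quick check (my own illustration, with an arbitrary channel count of 4):

```python
import torch.nn as nn

# A plain Conv3d's weight covers all input channels for each output channel,
# so the outputs are sums over input channels.
conv = nn.Conv3d(4, 4, kernel_size=3, stride=1, padding=1)
print(tuple(conv.weight.shape))  # (4, 4, 3, 3, 3)
```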

This block of code doesn’t sum the input channels, which is what I want. I just want to make sure that a separate kernel is used to map each input channel to its corresponding output channel (as opposed to one shared kernel mapping every input channel to its corresponding output channel):

class separateChannelsConv(nn.Module):

    def __init__(self, channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # groups=channels makes this a depthwise convolution: each input
        # channel is convolved on its own and written to the corresponding
        # output channel, with no cross-channel summation.
        self.conv = nn.Conv3d(channels, channels, kernel_size, stride, padding, groups=channels)

    def forward(self, x):
        return self.conv(x)

Could you please confirm that using “groups” is the right strategy to achieve my goal of mapping each input channel to the corresponding output channel without summation over the other input channels?

Thanks a lot,
Arman

Hi Arman!

Yes, setting groups = channels will do what you want.

Let’s illustrate this with a simple case:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> conv = torch.nn.Conv3d (3, 3, 1, groups = 3, bias = False)
>>> conv.weight[:, :, 0, 0, 0]
tensor([[-0.5339],
        [-0.6689],
        [-0.9413]], grad_fn=<SelectBackward>)
>>> conv (torch.tensor ([1.0, 0.0, 0.0]).reshape (1, 3, 1, 1, 1))
tensor([[[[[-0.5339]]],


         [[[ 0.0000]]],


         [[[ 0.0000]]]]], grad_fn=<MkldnnConvolutionBackward>)
>>> conv (torch.tensor ([0.0, 1.0, 0.0]).reshape (1, 3, 1, 1, 1))
tensor([[[[[ 0.0000]]],


         [[[-0.6689]]],


         [[[ 0.0000]]]]], grad_fn=<MkldnnConvolutionBackward>)
>>> conv (torch.tensor ([0.0, 0.0, 1.0]).reshape (1, 3, 1, 1, 1))
tensor([[[[[ 0.0000]]],


         [[[ 0.0000]]],


         [[[-0.9413]]]]], grad_fn=<MkldnnConvolutionBackward>)
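You can also confirm from the weight shape that each channel gets its own separate kernel (this check is my addition to the session above): with groups = channels, the weight has shape (channels, 1, kD, kH, kW), i.e. one single-input-channel kernel per output channel, rather than one kernel shared across channels.

```python
import torch.nn as nn

channels = 3
conv = nn.Conv3d(channels, channels, kernel_size=3, padding=1, groups=channels)
# One kernel per channel: the in_channels dimension of the weight is
# channels / groups = 1, so no cross-channel mixing can occur.
print(tuple(conv.weight.shape))  # (3, 1, 3, 3, 3)
```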

Best.

K. Frank


Thank you so much, Frank!