Bypass individual channels in nn.Conv2d

In a recent paper (https://arxiv.org/abs/2003.11883) the authors talk about forwarding only 25% of randomly sampled channels of a feature map through a conv layer while “bypassing” the others to reduce the memory footprint.

Here is the corresponding section in that paper:

How would you achieve something like this in PyTorch?

I figured you could achieve this with something like the following example, but maybe there is a simpler and more obvious way that I am missing?

import torch
import torch.nn as nn
import torch.nn.functional as F


class SampleConv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels,
            kernel_size, stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias)

        self.channel_sample_ratio = 0.25

    def forward(self, x):
        # draw a random value per input channel; channels whose value is below the
        # sampling ratio (~25% of them) are passed through the convolution
        sampled_channels = torch.rand(x.size(1), device=x.device, requires_grad=False)

        # slice the weight along the in_channels dimension so it matches the sampled input
        weight = self.weight[:, sampled_channels < self.channel_sample_ratio]
        bias = self.bias

        y = F.conv2d(x[:, sampled_channels < self.channel_sample_ratio], weight, bias, self.stride, self.padding, self.dilation, self.groups)
        return y

Also, I measured the memory consumption of the above SampleConv and its regular nn.Conv2d equivalent, and it doesn’t seem to require 75% less memory.

I’m not sure I understand the section correctly, but it seems that 25% of the input channels are randomly sampled and passed to the conv layer (or the “mixed transformation of O”?), while the remaining channels are “bypassed”?

Does the second part mean that 75% of the input channels are just concatenated to the output of the conv operation?

I think you can assume that input_channels = output_channels so the addition in eq. 7 could make sense (instead of concatenation).

But apart from that, I understood the section in the same way as you.

Disabling gradients for certain channels is not possible, correct?

The number of output channels is defined by the number of kernels.
If you want to remove 75% of the input channels and would like to keep the same amount of output channels, you would just remove kernels.

I might be wrong, but that’s not my understanding of the section.

You could zero out the gradients for specific channels.
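For example via a hook on the weight’s gradient (a minimal sketch; the frozen channel indices are made up):

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, 3, padding=1)

# hypothetical example: freeze the weights acting on input channels 0-3
frozen_in_channels = torch.arange(4)

def zero_channel_grads(grad):
    # grad has the weight's shape: [out_channels, in_channels, kH, kW]
    grad = grad.clone()
    grad[:, frozen_in_channels] = 0.
    return grad

conv.weight.register_hook(zero_channel_grads)

out = conv(torch.randn(1, 8, 4, 4))
out.mean().backward()
print(conv.weight.grad[:, frozen_in_channels].abs().sum())  # prints 0.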

So you would also do something like my SampleConv approach from my first post, i.e. selecting certain kernels? Would it make sense to also only select certain output channels?

And yes, at second glance I think you are totally right. Addition makes no sense at all so concatenation might be the way to go.

Unfortunately I still can’t wrap my head around their claimed 75% decrease in memory consumption.

In your code you are slicing the input channels, i.e. you are making the kernels “smaller”:

self.weight[:, sampled_channels < self.channel_sample_ratio]

The weight parameter is defined as [out_channels, in_channels, height, width].
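To make the difference concrete (the shapes are made up):

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
print(conv.weight.shape)           # torch.Size([16, 8, 3, 3]) -> [out_channels, in_channels, kH, kW]

mask = torch.rand(8) < 0.25        # random mask over the 8 input channels
print(conv.weight[:, mask].shape)  # [16, n_sampled, 3, 3]: smaller kernels, same number of kernels
print(conv.weight[:4].shape)       # [4, 8, 3, 3]: fewer kernels, i.e. fewer output channels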

I don’t understand the posted section as removing kernels, but rather as removing input channels and bypassing the removed channels (i.e. not applying the convolution to them). But I’m still unsure. :wink:

That’s also why I think I might misunderstand the section, so let’s wait for smarter users who could help us out. :slight_smile:

:smiley: Yep, let’s wait for them. I also emailed the authors, so maybe they will answer me.


Please update this thread once you get an answer, as I’m curious how it should be implemented.


Unfortunately, I haven’t received an answer yet. In general it’s really a pity that it is so hard in academia to get an answer from the authors. /endofrant

Looking at Fig. 1 from that paper, I am now sure that concatenation is meant:

Those 3x3, 5x5, and 7x7 separable convs are operations from their NAS search space. As you can see, in this example two feature maps are sampled from the input and fed through those operations, while the other ones are bypassed and concatenated to the output.

Therefore, I came up with this “solution”:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SampleConv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, sampling_ratio=0.25, stride=1, padding=0, dilation=1,
                 groups=1, bias=True):
        super().__init__(
            in_channels, out_channels,
            kernel_size, stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias)

        self.channel_sample_ratio = sampling_ratio

    def forward(self, x, sample=True):
        if sample:
            # boolean mask selecting ~sampling_ratio of the channels at random
            mask = torch.rand(x.size(1), device=x.device) < self.channel_sample_ratio

            # slice the weight along both dimensions so the sampled kernels only act
            # on the sampled input channels (requires in_channels == out_channels)
            weight = self.weight[:, mask][mask]
            bias = self.bias[mask] if self.bias is not None else None

            # overwrite the sampled channels in-place; the remaining channels are bypassed
            x[:, mask] = F.conv2d(
                x[:, mask], weight, bias, self.stride, self.padding,
                self.dilation, self.groups)
        else:
            # regular convolution over all channels
            x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x

I measured the GPU memory consumption and it requires 50%-60% less memory than the regular nn.Conv2d counterpart (which is still a far cry from the stated 75%).
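For reference, a rough way to compare the peak memory (just a sketch using torch.cuda.max_memory_allocated; not necessarily identical to the script I actually used):

import torch
import torch.nn as nn

def peak_memory_mb(make_conv, n_iters=100):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    conv = make_conv().cuda()
    for _ in range(n_iters):
        # fresh input every iteration, since SampleConv modifies its input in-place
        x = torch.randn(1, 320, 48, 48, device='cuda')
        out = conv(x)
        out.mean().backward()
        conv.zero_grad()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2

print('SampleConv: {:.1f} MB'.format(peak_memory_mb(lambda: SampleConv(320, 320, 1, sampling_ratio=0.25))))
print('nn.Conv2d:  {:.1f} MB'.format(peak_memory_mb(lambda: nn.Conv2d(320, 320, 1))))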

Also, the L1-Error for estimating a torch.randn tensor is slightly higher at convergence (0.83 vs. 0.799) than with a regular nn.Conv2d:

sample_conv = SampleConv(320, 320, 1, sampling_ratio=0.25).cuda()
normal_conv = nn.Conv2d(320, 320, 1).cuda()

optim_sample_conv = torch.optim.SGD(sample_conv.parameters(), lr=0.1, momentum=0.9)
optim_normal_conv = torch.optim.SGD(normal_conv.parameters(), lr=0.1, momentum=0.9)
target = torch.randn(1, 320, 48, 48).cuda()

for e in range(10000):
    # sample conv
    input_t = torch.ones(1, 320, 48, 48).cuda()
    y = sample_conv(input_t)
    loss = (y - target).abs().mean()
    optim_sample_conv.zero_grad()
    loss.backward()
    optim_sample_conv.step()

    with torch.no_grad():
        y = sample_conv(input_t, sample=False)
        l1_error = (y - target).abs().mean()

    print("SampleConv: Loss: {} L1-Error: {}".format(loss, l1_error))

    # regular nn.Conv2d
    input_t = torch.ones(1, 320, 48, 48).cuda()
    y = normal_conv(input_t)
    loss = (y - target).abs().mean()
    optim_normal_conv.zero_grad()
    loss.backward()
    optim_normal_conv.step()

    with torch.no_grad():
        y = normal_conv(input_t)
        l1_error = (y - target).abs().mean()

    print("nn.Conv2d: Loss: {} L1-Error: {}".format(loss, l1_error))

What do you think?

Yeah, the figure explains it pretty well. Thanks for sharing.

I think the general idea looks alright, and I assume you are using grouped convolutions, as you are slicing the weight in two dimensions?
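For context, a grouped convolution already stores a weight that is “smaller” along the in_channels dimension, e.g.:

import torch.nn as nn

# a standard conv stores a weight of shape [out_channels, in_channels, kH, kW]
print(nn.Conv2d(16, 32, 3).weight.shape)            # torch.Size([32, 16, 3, 3])

# with groups=4 each kernel only sees in_channels / groups input channels
print(nn.Conv2d(16, 32, 3, groups=4).weight.shape)  # torch.Size([32, 4, 3, 3])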

No, I haven’t used grouped convolutions yet. What would be the motivation behind this?

Another thing that I find confusing is that they use InvertedResidual blocks (MobileNetV3), which look like input_channels --> hidden_channels --> hidden_channels --> output_channels. But my implementation from above requires input_channels == output_channels, so I can’t actually use it for the MBConv block (input_channels != hidden_channels).

Edit: I think we should only use this SampleConv on the middle conv (hidden -> hidden) of the InvertedResidual Block. The in- and out-convolutions are just 1x1 and not that expensive.
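As a rough sketch, the usage I have in mind would look like this (a simplified inverted-residual-style block without the depthwise and SE parts; the names are made up and it reuses the SampleConv from above):

import torch.nn as nn

class SampledInvertedResidual(nn.Module):
    """Sketch: only the hidden -> hidden conv uses channel sampling."""
    def __init__(self, in_channels, out_channels, hidden_channels, sampling_ratio=0.25):
        super().__init__()
        self.expand = nn.Conv2d(in_channels, hidden_channels, 1)          # cheap 1x1
        self.mid = SampleConv(hidden_channels, hidden_channels, 3,
                              sampling_ratio=sampling_ratio, padding=1)   # sampled 3x3 (in == out)
        self.project = nn.Conv2d(hidden_channels, out_channels, 1)        # cheap 1x1
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(self.expand(x))
        x = self.act(self.mid(x))
        return self.project(x)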

From the figure it also seems that the sampled channels are only concatenated and not modified in-place like in my implementation. Wouldn’t that be very bad for succeeding layers, as they would see different types of feature maps (produced with different kernels) at the same positions?
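For comparison, a concatenation-based forward could look roughly like this (just a sketch as a drop-in replacement for SampleConv.forward; note that the channel order of the output then differs from the input, which is exactly the part I find questionable):

import torch
import torch.nn.functional as F

def forward_concat(self, x):
    # boolean mask selecting ~channel_sample_ratio of the channels at random
    mask = torch.rand(x.size(1), device=x.device) < self.channel_sample_ratio

    # sampled kernels acting on the sampled input channels (requires in_channels == out_channels)
    weight = self.weight[:, mask][mask]
    bias = self.bias[mask] if self.bias is not None else None

    processed = F.conv2d(x[:, mask], weight, bias, self.stride,
                         self.padding, self.dilation, self.groups)
    bypassed = x[:, ~mask]  # untouched channels

    # assumes the conv preserves the spatial size (e.g. 1x1 or padded 3x3 with stride 1)
    return torch.cat([processed, bypassed], dim=1)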