Apply functional.conv2d across minibatch dimension

Hello all,

I have a network that generates kernel weights for a 2D convolution operation. It takes a single input and generates a weight vector, which is then reshaped into a KxK kernel, where K is the kernel size. Finally, I apply this kernel to an image. I’m trying to make this work with batches: when the network generates an [N, KxK] weight tensor, I would like to reshape it into an [N, K, K] tensor and apply it to [N, 1, H, W] images, where N is the batch size and H and W are the image height and width. I can do this with a for loop over the batch, but I’m wondering if there is a more efficient way. Below is my network:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FilterNet(nn.Module):

    def __init__(self, obs_size, hidden_size, kernel_size, activation_func='tanh'):
        super(FilterNet, self).__init__()
        self.kernel_size = kernel_size

        self.fc_obs1 = nn.Linear(obs_size, hidden_size)
        self.fc_obs2 = nn.Linear(hidden_size, hidden_size)
        self.fc_obs3 = nn.Linear(hidden_size, hidden_size)
        self.fc_obs4 = nn.Linear(hidden_size, kernel_size * kernel_size)

        self.act_func = getattr(torch, activation_func)

    def forward(self, images, tau):
        hidden = self.act_func(self.fc_obs1(tau))
        hidden = self.act_func(self.fc_obs2(hidden))
        hidden = self.act_func(self.fc_obs3(hidden))
        kernel_weights = self.fc_obs4(hidden)  # [N, K*K]

        # apply each sample's kernel to its own image, one sample at a time
        heatmaps = []
        for weights_i, image_i in zip(kernel_weights.chunk(kernel_weights.size(0), dim=0),
                                      images.chunk(images.size(0), dim=0)):
            kernel = weights_i.view(1, 1, self.kernel_size, self.kernel_size)
            heatmap = F.conv2d(image_i, kernel, padding=2)
            heatmap = heatmap.squeeze()
            heatmaps += [heatmap]
        heatmaps = torch.stack(heatmaps, 0)

        # kernel_weights = kernel_weights.view(kernel_weights.size(0), 1, self.kernel_size, self.kernel_size)
        # heatmaps = F.conv2d(images, kernel_weights, padding=2)
        
        return heatmaps

Note that the commented-out section of the code does not work.

Thanks!

Since you are using single-channel images, you could swap the batch and channel dimensions and use a grouped convolution:

class FilterNet(nn.Module):

    def __init__(self, obs_size, hidden_size, kernel_size, activation_func='tanh'):
        super(FilterNet, self).__init__()
        self.kernel_size = kernel_size

        self.fc_obs1 = nn.Linear(obs_size, hidden_size)
        self.fc_obs2 = nn.Linear(hidden_size, hidden_size)
        self.fc_obs3 = nn.Linear(hidden_size, hidden_size)
        self.fc_obs4 = nn.Linear(hidden_size, kernel_size * kernel_size)

        self.act_func = getattr(torch, activation_func)

    def forward(self, images, tau, use_loop=False):
        hidden = self.act_func(self.fc_obs1(tau))
        hidden = self.act_func(self.fc_obs2(hidden))
        hidden = self.act_func(self.fc_obs3(hidden))
        kernel_weights = self.fc_obs4(hidden)

        if use_loop:
            heatmaps = []
            for weights_i, image_i in zip(kernel_weights.chunk(kernel_weights.size(0), dim=0),
                                          images.chunk(images.size(0), dim=0)):
                kernel = weights_i.view(1, 1, self.kernel_size, self.kernel_size)
                heatmap = F.conv2d(image_i, kernel, padding=2)
                heatmap = heatmap.squeeze()
                heatmaps += [heatmap]
            heatmaps = torch.stack(heatmaps, 0)
        else:
            # reshape kernels to [N, 1, K, K] and treat each sample as its own input channel
            kernel_weights = kernel_weights.view(kernel_weights.size(0), 1, self.kernel_size, self.kernel_size)
            images = images.permute(1, 0, 2, 3)  # [N, 1, H, W] -> [1, N, H, W]
            heatmaps = F.conv2d(images, kernel_weights, padding=2, groups=images.size(1))
            
        return heatmaps


model = FilterNet(10, 10, 3)
images = torch.randn(16, 1, 24, 24)
tau = torch.randn(16, 10)

output_loop = model(images, tau, use_loop=True)
output_grouped = model(images, tau, use_loop=False)

print((output_loop - output_grouped).max())
> tensor(2.3842e-07, grad_fn=<MaxBackward1>)

This applies each kernel to a single input channel, so in effect we are applying each kernel to a single sample from the batch.
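
For concreteness, here is the shape bookkeeping of the grouped path, written out as a standalone sketch with the same sizes as the example above (the tensors here are just random placeholders):

import torch
import torch.nn.functional as F

N, K, H, W = 16, 3, 24, 24
images = torch.randn(N, 1, H, W)
kernels = torch.randn(N, 1, K, K)                  # one single-channel kernel per sample

x = images.permute(1, 0, 2, 3)                     # [N, 1, H, W] -> [1, N, H, W]
out = F.conv2d(x, kernels, padding=2, groups=N)    # each group sees exactly one sample
print(out.shape)                                   # torch.Size([1, 16, 26, 26])

Note that the loop branch returns a [16, 26, 26] tensor (after the squeeze/stack), while the grouped branch returns [1, 16, 26, 26]; the difference check above broadcasts over that leading dimension, so add a final squeeze or view if you need identical shapes.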

Thanks a lot, @ptrblck. This is exactly what I wanted.
I just did a quick comparison, and it seems the grouped convolution is a lot faster than the for loop: a single forward pass used to take 0.01 seconds and now takes 0.0002 seconds.
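
In case it helps anyone reproducing the comparison, this is roughly how one might time the two paths (a rough sketch; the exact numbers will of course depend on hardware and tensor sizes):

import time
import torch

model = FilterNet(10, 10, 3)
images = torch.randn(16, 1, 24, 24)
tau = torch.randn(16, 10)

with torch.no_grad():
    for use_loop in (True, False):
        model(images, tau, use_loop=use_loop)        # warm-up pass
        start = time.perf_counter()
        for _ in range(100):
            model(images, tau, use_loop=use_loop)
        elapsed = (time.perf_counter() - start) / 100
        print(f'use_loop={use_loop}: {elapsed:.6f} s per forward pass')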

@ptrblck I would like to do the same thing with multi-channel inputs and kernels now.
So I have a tensor of shape [N, C, W, H] as the input and a tensor of shape [N, K, kW, kH] for the kernels, where N is the batch size, C is the number of image channels, K is the number of kernels, and K=C. I want the first K filters from the kernel batch to be applied to the first image (with its C channels) from the image batch, and so on. Each filter k should be applied to only one channel c. The output should be a feature map of size [N, C, W, H].
Can I do this without a for loop too?
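
For what it's worth, one possible way to extend the trick (assuming the intended pairing is: kernel c of sample n is applied to channel c of sample n) would be to fold the batch and channel dimensions together and use groups=N*C. The sketch below is just an illustration with made-up sizes, not a tested answer from the thread:

import torch
import torch.nn.functional as F

N, C, H, W, k = 4, 3, 24, 24, 3
images = torch.randn(N, C, H, W)
kernels = torch.randn(N, C, k, k)                  # K == C: one kernel per channel per sample

x = images.reshape(1, N * C, H, W)                 # fold batch into the channel dimension
weight = kernels.reshape(N * C, 1, k, k)           # one single-channel filter per group
out = F.conv2d(x, weight, padding=k // 2, groups=N * C)
out = out.reshape(N, C, out.size(2), out.size(3))  # back to [N, C, H, W]
print(out.shape)                                   # torch.Size([4, 3, 24, 24])

Since the flattening preserves ordering, channel c of sample n ends up paired with kernel c of the same sample; it is worth double-checking against a small loop version that this matches the pairing you have in mind.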