How to run functional.conv2d with different weights for each sample in batch?

I have a different weight tensor for each sample in the batch. How can I run conv2d with per-sample weights on the GPU in a single call? According to the functional.conv2d docs, it applies the same weights to every sample in the batch. The slow workaround I found is to call functional.conv2d in a for loop, once per sample.

Thanks

You could move the batch dimension of the input into the channels and use the groups argument of F.conv2d.
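A minimal sketch of that suggestion (shapes are made up for illustration): fold the batch dimension N into the channels, then pass groups=N so each group of channels is convolved with its own sample's filters.

```python
import torch
import torch.nn.functional as F

N, C_in, C_out, H, W = 4, 3, 2, 5, 5
x = torch.randn(N, C_in, H, W)         # per-sample inputs
w = torch.randn(N, C_out, C_in, 3, 3)  # one weight set per sample

# (N, C_in, H, W) -> (1, N*C_in, H, W): batch folded into channels
x_g = x.reshape(1, N * C_in, H, W)
# (N, C_out, C_in, kh, kw) -> (N*C_out, C_in, kh, kw)
w_g = w.reshape(N * C_out, C_in, 3, 3)

# groups=N: group i sees channels [i*C_in:(i+1)*C_in] and
# filters [i*C_out:(i+1)*C_out], i.e. sample i's own weights
out = F.conv2d(x_g, w_g, groups=N)           # (1, N*C_out, H', W')
out = out.reshape(N, C_out, *out.shape[2:])  # unfold back to (N, C_out, H', W')

# matches running each sample separately
ref = torch.cat([F.conv2d(x[i:i + 1], w[i]) for i in range(N)])
torch.testing.assert_close(out, ref)
```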

Can you elaborate? I can’t get it to work…

import torch
import torch.nn.functional as F

# input shape (N,C_in,H,W)
input = torch.zeros((4,8,9,16))

# weight shape (N,C_out,C_in,1,1)
weight = torch.zeros((4,5,8,1,1))

# new_shape (C_out, N*C_in, h_k, w_k)
new_shape = (weight.shape[1],weight.shape[0]*weight.shape[2]) + weight.shape[3:]
result = F.conv2d(input, weight.reshape(new_shape), bias=None, stride=1, groups=weight.shape[0])

# Both C_out AND C_in are divided by groups
"RuntimeError: Given groups=4, expected weight to be divisible by 4 at dimension 0, but got weight of size [[5, 32, 1, 1]] instead"

# new_shape (N*C_out, N*C_in, h_k, w_k)
new_shape = (weight.shape[0]*weight.shape[1],weight.shape[0]*weight.shape[2]) + weight.shape[3:]
result = F.conv2d(input, weight.reshape(new_shape), bias=None, stride=1, groups=weight.shape[0])

# Does not work, either (obviously)
"RuntimeError: shape '[20, 32, 1, 1]' is invalid for input of size 160"

What I want is this:

# Expected result shape (N,C_out,H,W)
result = torch.cat([F.conv2d(input[i].unsqueeze(0), w, bias=None, stride=1) for i,w in enumerate(weight)], dim=0)
result.shape

But it seems wasteful to run N separate convolutions:

The slower way I found is to run functional.conv2d in a for loop over each data sample.

Is there another way?

I have figured it out:


%%timeit
# Expected result shape (N,C_out,H,W)
expected_result = torch.cat([F.conv2d(input[i].unsqueeze(0), w, bias=None, stride=1) for i,w in enumerate(weight)], dim=0)

"854 µs ± 80.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)"

%%timeit
# Fold the batch into the channels: (N, C_in, H, W) -> (1, N*C_in, H, W)
s_in = (1, -1) + input.shape[2:]
# Flatten the per-sample filters: (N, C_out, C_in, 1, 1) -> (N*C_out, C_in, 1, 1)
s_w = (-1,) + weight.shape[2:]
# Unfold the result: (1, N*C_out, H, W) -> (N, C_out, H, W)
s_out = (input.shape[0], weight.shape[1]) + input.shape[2:]
result = F.conv2d(input.reshape(s_in), weight.reshape(s_w), bias=None, stride=1, groups=weight.shape[0]).reshape(s_out)

"305 µs ± 52.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)"

torch.testing.assert_close(result, expected_result)

(Unscientific test using CUDA tensors on an Intel(R) Core™ i7-6700K CPU @ 4.00GHz and NVIDIA GeForce GTX TITAN X system.)

The key lies in modifying the input shape as well.
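One caveat: with kernels larger than 1x1 (or with stride/padding), the output spatial size no longer equals the input's, so s_out cannot be taken from input.shape. A hedged generalization that derives the output size from the conv result instead (the helper name batched_conv2d is mine, not from the thread):

```python
import torch
import torch.nn.functional as F

def batched_conv2d(x, w, stride=1, padding=0):
    """Conv2d with a separate weight tensor per sample, via one grouped conv.

    x: (N, C_in, H, W); w: (N, C_out, C_in, kh, kw).
    """
    N, C_in = x.shape[:2]
    C_out = w.shape[1]
    out = F.conv2d(x.reshape(1, N * C_in, *x.shape[2:]),
                   w.reshape(N * C_out, C_in, *w.shape[3:]),
                   stride=stride, padding=padding, groups=N)
    # Take the spatial dims from the conv output rather than from x,
    # so non-1x1 kernels, stride, and padding are handled correctly.
    return out.reshape(N, C_out, *out.shape[2:])

x = torch.randn(4, 8, 9, 16)
w = torch.randn(4, 5, 8, 3, 3)
res = batched_conv2d(x, w, padding=1)
ref = torch.cat([F.conv2d(x[i:i + 1], w[i], padding=1) for i in range(4)])
torch.testing.assert_close(res, ref)
```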
