Convolution that takes a function as kernel

I am currently experimenting with an idea that would require me to have a “dynamic” kernel (a kernel that changes with the input).

So for each input “patch”, I would have a function f (a simple MLP) that produces the desired filter for this specific part of the image, but it seems the convolution-operator only takes a static filter.

How can I achieve this?



There is no builtin way to do this.
But you can do it by using multiple convs:
if you want your filter to be an MLP that takes a chan_in x in x in patch, has a hidden layer of size hidden, and outputs chan_out values, then the whole thing maps an image of size Batch x chan_in x in x in -> Batch x chan_out x 1 x 1 (making the image larger than in x in will increase the output size from 1 to something else, depending on stride/padding).
You can use two convs with the following parameters:

  • first convolution will change the channels from chan_in to hidden and kernel size will be in x in with the stride and padding of the original conv
  • add a ReLU or whatever non-linearity you want here.
  • second conv will change channels from hidden to chan_out and kernel size will be 1 x 1, stride 1 and no padding.

I haven’t tested so some numbers might be off, but that should work.
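As a rough sketch of the two-conv recipe above: the sizes below (3 input channels, 16 hidden units, 2 output channels, 5 x 5 patches) are made up for illustration and are not from the thread. With stride 1 and no padding, a 5 x 5 input gives a 1 x 1 output, and a larger image gives one MLP evaluation per sliding window.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not values from the thread).
chan_in, hidden, chan_out, k = 3, 16, 2, 5

# An MLP applied to every k x k patch, written as two convolutions.
mlp_as_conv = nn.Sequential(
    nn.Conv2d(chan_in, hidden, kernel_size=k, stride=1, padding=0),  # patch -> hidden
    nn.ReLU(),
    nn.Conv2d(hidden, chan_out, kernel_size=1),                      # hidden -> chan_out
)

patch = torch.randn(4, chan_in, k, k)      # exactly one patch per sample
image = torch.randn(4, chan_in, 12, 12)    # many overlapping patches per sample

patch_out = mlp_as_conv(patch)
image_out = mlp_as_conv(image)
print(patch_out.shape)  # torch.Size([4, 2, 1, 1])
print(image_out.shape)  # torch.Size([4, 2, 8, 8])
```

Note that this evaluates an MLP on each patch; it does not use the MLP output as a filter.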


Thanks, interesting idea. But if I understood you right, this will only evaluate f and not apply the filter. My function f doesn’t produce the result; it produces the filter for the convolution at that specific input patch. I don’t think I can use a convolution to apply the filter produced by f, because a convolution with a fixed kernel can only compute the same weighted sum at every location.

Oh sorry, I misread your question, I thought you wanted the weights to be an MLP.
So for every patch you want an MLP to generate the weights, and then apply these weights to that patch?

In that case you will need to use unfold. Following the example in the doc, you will need to generate w from inp_unf, which contains every patch (there are L such patches). Since you want one set of weights per patch, your weights will have shape (N, patch_size, L, chan_out). Then replace the matmul that does the conv with an element-wise multiplication against the expanded inp_unf, and sum over the patch_size dimension.

This might not make sense so here is a small sample based on the unfold example (same notations as the ones introduced in the doc for unfold):

import torch 

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)

# Original conv
print("Original Conv")
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())

# Custom conv
print("Custom Conv")
def f(inp_unf, chan_out):
    # Input: (N, L, patch_size) that contains every single input patch
    # Output: (N, L, chan_out * patch_size) that contains the weights that will be used for every patch

    # Here you can have an MLP that has patch_size input features and chan_out * patch_size output features.

    # For simplicity (and check) we just expand the original weights here:
    output = w.view(-1).unsqueeze(0).unsqueeze(0)
    out_size = list(inp_unf.size())
    out_size[-1] *= chan_out
    return output.expand(*out_size)

inp_unf = torch.nn.functional.unfold(inp, (4, 5))
full_weights = f(inp_unf.transpose(1, 2), w.size(0))
# Reshape full_weights to the expected shape
full_weights = full_weights.view(inp_unf.size(0), inp_unf.size(2), w.size(0), inp_unf.size(1)).permute(0, 3, 1, 2)
# Compute the product weight*entry in each patch
full_out = inp_unf.unsqueeze(-1).expand_as(full_weights) * full_weights
# Sum over the entries of each patch (the patch_size dimension)
out_unf = full_out.sum(1)
# Put chan dim at the right place
out_unf = out_unf.transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())
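For reference, here is a minimal sketch of the same idea with an actual MLP in place of the expanded static weights. The module name `DynamicConv2d`, the hidden size, and the use of `torch.einsum` for the multiply-and-sum step are my own choices for illustration. It assumes stride 1 and no padding, like the sample above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicConv2d(nn.Module):
    """Hypothetical module: an MLP predicts one filter bank per input patch."""

    def __init__(self, chan_in, chan_out, kernel_size, hidden=32):
        super().__init__()
        self.chan_out = chan_out
        self.kernel_size = kernel_size  # (kh, kw)
        patch_size = chan_in * kernel_size[0] * kernel_size[1]
        # Maps patch_size features to chan_out * patch_size weights.
        self.mlp = nn.Sequential(
            nn.Linear(patch_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, chan_out * patch_size),
        )

    def forward(self, inp):
        n, _, h, w = inp.shape
        kh, kw = self.kernel_size
        # (N, L, patch_size): every patch of the input, one row per location.
        patches = F.unfold(inp, self.kernel_size).transpose(1, 2)
        # (N, L, chan_out, patch_size): one filter bank per patch.
        weights = self.mlp(patches).view(n, patches.size(1), self.chan_out, -1)
        # out[n, o, l] = sum_p weights[n, l, o, p] * patches[n, l, p]
        out = torch.einsum("nlop,nlp->nol", weights, patches)
        # Stride 1, no padding: the L locations form an (h-kh+1, w-kw+1) grid.
        return out.reshape(n, self.chan_out, h - kh + 1, w - kw + 1)


layer = DynamicConv2d(chan_in=3, chan_out=2, kernel_size=(4, 5))
x = torch.randn(1, 3, 10, 12)
out = layer(x)
print(out.shape)  # torch.Size([1, 2, 7, 8])
```

The einsum fuses the element-wise product and the sum, so the (N, patch_size, L, chan_out) product is not written out explicitly in Python, though the per-patch weights themselves are still that large.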


  • The original conv ops are much more optimized than this, so even the unfold/matmul/fold version will be slower than conv2d.
  • There are a lot of LARGE intermediate tensors here, so the memory requirement for autograd is going to be quite large. Checkpointing might help if you really need to do this.
  • The MLP within f maps patch_size features to chan_out*patch_size features, so the MLP itself should be fairly small.
  • This same MLP will run with an effective batch size of N*L, where L is given by the formula in the doc for unfold. This will be HUGE, so be careful what you do here, as this can become very expensive (both in terms of runtime and memory) very quickly.
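To illustrate the checkpointing remark, here is a small sketch that wraps the weight-generating MLP and the multiply-and-sum step in `torch.utils.checkpoint.checkpoint`, so their intermediates are recomputed during backward instead of being stored. The MLP sizes and the helper name `weights_and_apply` are assumptions for illustration, and `use_reentrant=False` assumes a reasonably recent PyTorch version. I have not measured how much memory this saves in practice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

patch_size, chan_out = 3 * 4 * 5, 2  # same shapes as the sample above
mlp = nn.Sequential(
    nn.Linear(patch_size, 32),  # hidden size 32 is an arbitrary choice
    nn.ReLU(),
    nn.Linear(32, chan_out * patch_size),
)


def weights_and_apply(patches):
    # patches: (N, L, patch_size)
    n, l, p = patches.shape
    weights = mlp(patches).view(n, l, -1, p)         # (N, L, chan_out, patch_size)
    out = (weights * patches.unsqueeze(2)).sum(-1)   # (N, L, chan_out)
    return out.transpose(1, 2)                       # (N, chan_out, L)


inp = torch.randn(1, 3, 10, 12, requires_grad=True)
patches = F.unfold(inp, (4, 5)).transpose(1, 2)      # (1, 56, 60)

# Intermediates inside weights_and_apply are recomputed in backward.
out_unf = checkpoint(weights_and_apply, patches, use_reentrant=False)
out = out_unf.reshape(1, chan_out, 7, 8)
out.sum().backward()
print(out.shape, inp.grad.shape)
```

The trade-off is extra compute: the MLP runs a second time during the backward pass.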

Hope this helps.


Hmm, thanks for the idea.
I will try it and see whether it works well enough. Is there a way to extend PyTorch to provide this functionality (with reasonable effort)?

Not really. The convolutions are actually done (with some minor optimizations) exactly as in the first part of the sample above. This means that there is never a point in the code where you look at a single patch at a time.

Assuming the speed of the naive solution does not suffice, how difficult do you think implementing this by extending PyTorch would be?


I am not sure it would be possible to implement it much better.
If you compute the convolution as a matrix multiplication (MM), then wrapping the above code in a custom autograd Function would only save a bit of memory.
If you want to use other convolution algorithms, I am not even sure it is possible to do it.
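A minimal sketch of what wrapping the multiply-and-sum step in a Function could look like, using `torch.autograd.Function` with a hand-written backward. The class name `PerPatchContract` and the einsum formulation are my own assumptions, not something from the PyTorch codebase, and I have not measured whether it actually uses less memory than the plain version.

```python
import torch


class PerPatchContract(torch.autograd.Function):
    """Hypothetical Function: out[n, o, l] = sum_p patches[n, p, l] * weights[n, p, l, o]."""

    @staticmethod
    def forward(ctx, patches, weights):
        # patches: (N, patch_size, L); weights: (N, patch_size, L, chan_out)
        ctx.save_for_backward(patches, weights)
        return torch.einsum("npl,nplo->nol", patches, weights)

    @staticmethod
    def backward(ctx, grad_out):
        # grad_out: (N, chan_out, L)
        patches, weights = ctx.saved_tensors
        grad_patches = torch.einsum("nol,nplo->npl", grad_out, weights)
        grad_weights = torch.einsum("nol,npl->nplo", grad_out, patches)
        return grad_patches, grad_weights


# Small double-precision inputs so the gradients can be checked numerically.
patches = torch.randn(1, 6, 4, dtype=torch.double, requires_grad=True)
weights = torch.randn(1, 6, 4, 2, dtype=torch.double, requires_grad=True)

out = PerPatchContract.apply(patches, weights)
reference = (patches.unsqueeze(-1) * weights).sum(1).transpose(1, 2)
print(torch.allclose(out, reference))
print(torch.autograd.gradcheck(PerPatchContract.apply, (patches, weights)))
```

Only the two inputs are saved for backward here; the MLP that produces the weights still sees N*L rows, so this does not remove the main cost.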