Oh, sorry, I misread your question: I thought you wanted the weights themselves to be an MLP.
So you want, for every patch, an MLP that generates the weights, and then to apply these weights to that same patch?
In that case you will need to use unfold. Following the example in the unfold doc, you will need to generate `w` from `inp_unf`, which contains every patch (there are L such patches). And since you want one set of weights per patch, your weights will have shape (N, patch_size, L, chan_out). Then replace the matmul that does the conv with an element-wise multiplication (after expanding `inp_unf`) and sum over the entries of each patch.
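The shape bookkeeping of that multiply-and-sum step can be checked in isolation. This is a minimal sketch with made-up sizes (the names `N`, `P`, `L`, `C_out` are illustrative, `P` standing for patch_size); it also shows that the same contraction can be written with `torch.einsum`, which may avoid materializing the large (N, P, L, C_out) product.

```python
import torch

# Illustrative sizes (hypothetical): batch, patch_size, number of patches, output channels
N, P, L, C_out = 2, 60, 56, 4

inp_unf = torch.randn(N, P, L)          # every patch as a column, as returned by unfold
weights = torch.randn(N, P, L, C_out)   # one (P, C_out) weight matrix per patch

# Element-wise product, then sum over the entries of each patch -> (N, L, C_out)
out_mul = (inp_unf.unsqueeze(-1) * weights).sum(1)

# The same contraction written with einsum
out_einsum = torch.einsum("npl,nplc->nlc", inp_unf, weights)

print(out_mul.shape)  # torch.Size([2, 56, 4])
print(torch.allclose(out_mul, out_einsum, atol=1e-4))
```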
This might not make much sense in words, so here is a small sample based on the one in the unfold doc (same notation as the one introduced there):
```python
import torch

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)

# Original conv
print("Original Conv")
inp_unf = torch.nn.functional.unfold(inp, (4, 5))   # (1, 3*4*5 = 60, L = 56)
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)  # (1, 2, 56)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())
```
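For reference, the (7, 8) passed to fold and the number of patches L both come from the output-size formula in the unfold documentation. A minimal sketch, assuming stride 1, no padding and dilation 1 (the defaults used above); `num_patches` is a hypothetical helper, not a PyTorch function:

```python
import torch
import torch.nn.functional as F

def num_patches(spatial_size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: per-dimension output size and total number of
    patches L, following the formula in the unfold / nn.Unfold docs."""
    out = [
        (s + 2 * padding - dilation * (k - 1) - 1) // stride + 1
        for s, k in zip(spatial_size, kernel_size)
    ]
    L = 1
    for o in out:
        L *= o
    return out, L

inp = torch.randn(1, 3, 10, 12)
out_hw, L = num_patches((10, 12), (4, 5))
inp_unf = F.unfold(inp, (4, 5))
print(out_hw, L)      # [7, 8] 56
print(inp_unf.shape)  # torch.Size([1, 60, 56]); patch_size = 3 * 4 * 5 = 60
```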
```python
# Custom conv
print("Custom Conv")

def f(inp_unf, chan_out):
    # Input: (N, L, patch_size) that contains every single input patch
    # Output: (N, L, chan_out * patch_size) that contains the weights that will be used for every patch
    # Here you can have an MLP that has patch_size input features and chan_out * patch_size output features.
    # For simplicity (and to check the result) we just expand the original weights here
    # (note: this uses the global `w` defined above):
    output = w.view(-1).unsqueeze(0).unsqueeze(0)
    out_size = list(inp_unf.size())
    out_size[-1] *= chan_out
    return output.expand(*out_size)

inp_unf = torch.nn.functional.unfold(inp, (4, 5))
full_weights = f(inp_unf.transpose(1, 2), w.size(0))
# Reshape full_weights to the expected shape (N, patch_size, L, chan_out)
full_weights = full_weights.view(inp_unf.size(0), inp_unf.size(2), w.size(0), inp_unf.size(1)).permute(0, 3, 1, 2)
# Compute the product weight * entry in each patch
full_out = inp_unf.unsqueeze(-1).expand_as(full_weights) * full_weights
# Sum over the entries of each patch -> (N, L, chan_out)
out_unf = full_out.sum(1)
# Put the channel dim at the right place -> (N, chan_out, L)
out_unf = out_unf.transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())
```
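The `f` above only broadcasts the fixed `w` so the result can be checked against conv2d. Here is a minimal sketch of what it could look like with an actual MLP, wrapped in a module. The name `MLPPatchConv`, the two-layer `nn.Sequential`, and the hidden size are assumptions for illustration; it also assumes stride 1 and no padding, and uses `torch.einsum` for the multiply-and-sum step, which should compute the same thing as the expand/multiply/sum above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPPatchConv(nn.Module):
    """Hypothetical module: an MLP predicts a (chan_out, patch_size) weight
    matrix for every patch, which is then applied to that same patch."""

    def __init__(self, chan_in, chan_out, kernel_size, hidden=64):
        super().__init__()
        self.kernel_size = kernel_size
        self.chan_out = chan_out
        self.patch_size = chan_in * kernel_size[0] * kernel_size[1]
        # patch_size input features -> chan_out * patch_size output features
        self.mlp = nn.Sequential(
            nn.Linear(self.patch_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, chan_out * self.patch_size),
        )

    def forward(self, x):
        N, _, H, W = x.shape
        kh, kw = self.kernel_size
        out_h, out_w = H - kh + 1, W - kw + 1         # stride 1, no padding
        inp_unf = F.unfold(x, self.kernel_size)      # (N, patch_size, L)
        patches = inp_unf.transpose(1, 2)            # (N, L, patch_size)
        L = patches.size(1)
        # The MLP effectively runs on a batch of N * L patches
        w = self.mlp(patches)                        # (N, L, chan_out * patch_size)
        w = w.reshape(N, L, self.chan_out, self.patch_size)
        # Multiply each patch entry by its weight and sum over the patch entries
        out_unf = torch.einsum("nlp,nlcp->ncl", patches, w)  # (N, chan_out, L)
        return out_unf.reshape(N, self.chan_out, out_h, out_w)

layer = MLPPatchConv(chan_in=3, chan_out=2, kernel_size=(4, 5))
out = layer(torch.randn(1, 3, 10, 12))
print(out.shape)  # torch.Size([1, 2, 7, 8])
```

The weight layout (chan_out, patch_size) matches the `w.view(-1)` ordering used in `f` above, so replacing the MLP by a constant output would reproduce the plain conv.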
Disclaimer:
- The original conv ops are much more optimized than this, so even the unfold/matmul/fold version will be slower than conv2d.
- There are a lot of LARGE intermediate tensors here, so the memory requirement for autograd is going to be quite large. Checkpointing might help if you really need to do this.
- The MLP within `f` will map `patch_size` features to `chan_out * patch_size` features, which should be fairly small.
- This same MLP will run on a batch of size `N * L`, where L is given by the formula in the unfold documentation. This will be HUGE, so be careful what you do here, as it can become very expensive (both in runtime and memory) very quickly.
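Regarding the checkpointing suggestion above, here is a minimal sketch using `torch.utils.checkpoint.checkpoint`. The MLP (`mlp`) and the helper `weights_and_product` are hypothetical; the idea is to wrap the weight-generating MLP together with the multiply-and-sum step, so that the per-patch weights and the MLP activations are recomputed during backward instead of being stored.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

patch_size, chan_out = 3 * 4 * 5, 2
# Hypothetical weight-generating MLP: patch_size -> chan_out * patch_size
mlp = nn.Sequential(nn.Linear(patch_size, 64), nn.ReLU(),
                    nn.Linear(64, chan_out * patch_size))

def weights_and_product(patches):
    # patches: (N, L, patch_size)
    N, L, P = patches.shape
    w = mlp(patches).reshape(N, L, chan_out, P)        # large per-patch weights
    return torch.einsum("nlp,nlcp->nlc", patches, w)   # (N, L, chan_out)

inp = torch.randn(1, 3, 10, 12, requires_grad=True)
patches = F.unfold(inp, (4, 5)).transpose(1, 2)        # (1, 56, 60)
print("MLP batch size N * L =", patches.shape[0] * patches.shape[1])  # 56 here

# Tensors created inside weights_and_product are not kept for backward;
# they are recomputed when out.sum().backward() runs.
out_unf = checkpoint(weights_and_product, patches, use_reentrant=False)
out = F.fold(out_unf.transpose(1, 2), (7, 8), (1, 1))  # (1, 2, 7, 8)
out.sum().backward()
print(out.shape, inp.grad.shape)
```

This trades extra compute (the MLP runs twice) for lower peak memory; it does not reduce the cost of the `N * L` batch itself.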
Hope this helps