Thanks for the detailed description!
I think the best shot at the problem would be to use the functional API and to recreate the patches as you need them.
Since you only want to keep the center pixel, you could probably just use indexing of the input tensor to get these values.
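Here is a minimal sketch of that idea. It assumes an input of shape [N, C, H, W], an odd kernel size `k`, stride 1 and "same" padding (`k // 2`), and it uses `F.unfold` to build the patches, which is only one possible way to create them. It indexes the center position of every patch and checks that, under these assumptions, the centers are simply the input pixels:

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 5, 5
k = 3  # odd kernel size, so each patch has a single center pixel

x = torch.randn(N, C, H, W)

# All k x k patches with stride 1 and "same" padding: [N, C*k*k, H*W]
patches = F.unfold(x, kernel_size=k, padding=k // 2)

# Split the flattened dim into (C, k*k); unfold uses channel-major order
patches = patches.reshape(N, C, k * k, -1)

# Index the center of each patch (row-major position (k//2, k//2))
center_idx = (k * k) // 2
centers = patches[:, :, center_idx, :]  # [N, C, H*W]

# With stride 1 and padding k//2, the centers are just the input pixels
print(torch.equal(centers, x.reshape(N, C, -1)))
```

So for this setup you wouldn't even need to unfold anything to get the center values; plain indexing of `x` would do.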
How would you get the outer values? Do you have some kind of mapping, or would indexing / gather also work?
As you can see, I’m quite unsure about the creation of these input patches, so feel free to post some (code) examples.
Once you have the input patches, you could use F.conv2d and your weight parameter to perform the vanilla convolution.
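A rough sketch of this last step, assuming the patches come from `F.unfold` (stride 1, padding `k // 2`) and that `weight` has the usual `[out_channels, C, k, k]` layout. Each patch is reshaped into its own tiny `[C, k, k]` "image", so an unpadded `F.conv2d` produces exactly one value per filter and patch. The spot where you'd replace the outer values is only marked with a comment, since it depends on your mapping:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, C, H, W = 2, 3, 8, 8
out_channels, k = 4, 3

x = torch.randn(N, C, H, W)
weight = torch.randn(out_channels, C, k, k)

# Extract all k x k patches: [N, C*k*k, L] with L = H*W here
patches = F.unfold(x, kernel_size=k, padding=k // 2)
L = patches.size(-1)

# Placeholder: this is where the patch values (e.g. the outer values)
# would be modified; they are left unchanged in this sketch.

# One patch per "batch" entry: [N*L, C, k, k]
# (unfold's (C, kH, kW) ordering matches the weight layout)
patches = patches.transpose(1, 2).reshape(N * L, C, k, k)

# Unpadded conv over a k x k input -> one output value per filter
out = F.conv2d(patches, weight)  # [N*L, out_channels, 1, 1]

# Restore the spatial layout: [N, out_channels, H, W]
out = out.view(N, L, out_channels).permute(0, 2, 1).reshape(N, out_channels, H, W)

# Unmodified patches should reproduce a standard padded convolution
ref = F.conv2d(x, weight, padding=k // 2)
print(torch.allclose(out, ref, atol=1e-4))
```

Comparing against the standard `F.conv2d` call with unmodified patches is a quick sanity check before you plug in your custom patch creation.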