How to write a new convolution operator in PyTorch

We all know that the convolution operator is just a linear operator applied in a sliding-window way.

I am trying to implement the SimNet convolution, which merely applies an lp norm before the linear operation.

Or we can say, the original 1d convolution is:

    y[i] = sum_j w[j] * x[i + j]

And I want:

    y[i] = sum_j w[j] * |x[i + j] - z[j]|^p

What’s the best way to do this in PyTorch?

Wouldn’t something like this do the trick:

import torch

class CustomConv(torch.nn.Module):
    def __init__(self, p, *args, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(*args, **kwargs)
        self.p = p

    def forward(self, x, z):
        # elementwise |x - z|^p, then the usual convolution
        diff = x - z
        p_norm = diff.abs().pow(self.p)
        return self.conv(p_norm)

There might be some typos or little mistakes as I’m typing from my smartphone.
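It would be used roughly like this (the shapes here are just an example I made up):

    conv = CustomConv(2, in_channels=3, out_channels=8, kernel_size=5)
    x = torch.randn(1, 3, 32, 32)
    z = torch.randn(1, 3, 32, 32)   # same shape as x in this proposal
    out = conv(x, z)                # -> (1, 8, 28, 28)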

Thank you for your reply.
But…no
The z in there has the same size as self.conv.weight.
That’s the problem. :joy:

But if z is the same size as the convolutional weight, then x must be of equal size to calculate a norm, or am I mistaken there?

Maybe I was not so clear.
By “in the sliding window way” I mean that the convolution operator takes a patch of x to do the linear operation.
It looks like:

    y[i] = sum_j w[j] * x_patch[i][j],    where x_patch[i] = x[i : i + K]

Every point of the output feature map is computed from one patch of x; note the x_patch here.
Now, the lp norm should also be applied to x_patch.

Or we can say, the original 1d convolution is:

    y[i] = sum_j w[j] * x_patch[i][j]

And I want:

    y[i] = sum_j w[j] * |x_patch[i][j] - z[j]|^p
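A naive loop version of what I mean, in 1d (the function name and shapes are just for illustration):

    import torch

    def simnet_conv1d(x, w, z, p):
        # x: (L,) input; w, z: (K,) weight and template of the same size
        K = w.numel()
        out = torch.empty(x.numel() - K + 1)
        for i in range(out.numel()):
            x_patch = x[i:i + K]                       # the sliding window
            out[i] = (w * (x_patch - z).abs().pow(p)).sum()
        return out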

I’m afraid you have to implement the convolution yourself in Python using the already available PyTorch functions, as the convolutions themselves are implemented in C.

Alternatively, you could have a look at the C implementation, which can be found here, and check whether it would be easier to modify it. On this tutorial page you can find examples of how to extend PyTorch.

I think the best approach would be to write your own autograd function, as suggested in this post.
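Just as a skeleton of what I mean (my own sketch, not code from the linked post), a custom autograd function for the elementwise |x - z|^p part could look like:

    import torch

    class LpDistance(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, z, p):
            diff = x - z
            ctx.save_for_backward(diff)
            ctx.p = p
            return diff.abs().pow(p)

        @staticmethod
        def backward(ctx, grad_out):
            diff, = ctx.saved_tensors
            p = ctx.p
            # d/dx |x - z|^p = p * |x - z|^(p-1) * sign(x - z)
            g = grad_out * p * diff.abs().pow(p - 1) * diff.sign()
            return g, -g, None

You would call it via LpDistance.apply(x, z, p) and feed the result into a normal convolution.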


Thank you for your advice.
The C-implementation way is a little tricky, since I found the pytorch -> torch -> ATen -> cuDNN dependency chain :joy:, which I am not good at.
The autograd way looks easier, but I would have to use a for loop to implement it. Well, I will try, thank you.
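For reference, here is a rough sketch of what I might try instead of the plain for loop, based on torch.nn.functional.unfold (untested; the class name and parameter layout are my own):

    import torch
    import torch.nn.functional as F

    class SimNetConv2d(torch.nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, p=2.0):
            super().__init__()
            self.kernel_size = kernel_size
            self.p = p
            k2 = in_channels * kernel_size * kernel_size
            self.weight = torch.nn.Parameter(torch.randn(out_channels, k2))
            self.z = torch.nn.Parameter(torch.randn(k2))

        def forward(self, x):
            n, c, h, w = x.shape
            # patches: (n, c*k*k, L), one column per sliding window
            patches = F.unfold(x, self.kernel_size)
            # lp distance of every patch element to the template z
            dist = (patches - self.z.view(1, -1, 1)).abs().pow(self.p)
            # weighted sum per output channel: (n, out_channels, L)
            out = torch.einsum('ok,nkl->nol', self.weight, dist)
            out_h = h - self.kernel_size + 1
            out_w = w - self.kernel_size + 1
            return out.view(n, -1, out_h, out_w)

Since this only uses differentiable built-in ops, autograd should handle the backward pass automatically, without a hand-written function.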