How to define a new layer with autograd?

Hi all,

I am new to PyTorch and also new to autograd. My question is: do I need to compute the partial derivatives with respect to my function's parameters myself?
For example, my new layer wants to compute a 1-D Gaussian probability density value. The function is
f(x) = a * exp(-(x - b)^2 / c),
where a, b and c are the parameters that need to be updated. I think all of these are basic operations, and the output is a scalar. Do I still need to write backward code like we do in Caffe? Or can I just define a new module with only a forward function, and PyTorch will compute the parameters' derivatives automatically for me?
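For reference, these are the partial derivatives you would otherwise write by hand in a Caffe-style backward pass, assuming the Gaussian form f(x) = a exp(-(x - b)^2 / c). They are what autograd works out for you from the forward computation:

```latex
\begin{aligned}
f(x) &= a\,e^{-(x-b)^2/c} \\
\frac{\partial f}{\partial a} &= e^{-(x-b)^2/c} \\
\frac{\partial f}{\partial b} &= a\,e^{-(x-b)^2/c}\,\frac{2(x-b)}{c} \\
\frac{\partial f}{\partial c} &= a\,e^{-(x-b)^2/c}\,\frac{(x-b)^2}{c^2} \\
\frac{\partial f}{\partial x} &= -a\,e^{-(x-b)^2/c}\,\frac{2(x-b)}{c}
\end{aligned}
```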

Sure, this will be handled for you. For example:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Gaussian(nn.Module):
    def __init__(self):
        self.a = nn.Parameter(torch.zeros(1))
        self.b = nn.Parameter(torch.zeros(1))
        self.c = nn.Parameter(torch.ones(1))  # non-zero: c is used as a divisor

    def forward(self, x):
        # unfortunately we don't have automatic broadcasting yet
        a = self.a.expand_as(x)
        b = self.b.expand_as(x)
        c = self.c.expand_as(x)
        return a * torch.exp((x - b)^2 / c)

module = Gaussian()
x = Variable(torch.randn(20))
out = module(x)
loss = loss_fn(out)  # loss_fn stands for whatever loss function you use
loss.backward()

# Now module.a.grad should be non-zero.
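A quick way to convince yourself that autograd really produces the partial derivatives is to compare its gradients with hand-derived ones. This is a sketch in current PyTorch (plain tensors with requires_grad=True instead of Variable, and broadcasting instead of expand_as); it assumes the Gaussian form f(x) = a * exp(-(x - b)**2 / c), and the values of a, b and c are arbitrary examples:

```python
import torch

torch.manual_seed(0)
x = torch.randn(20)
# Arbitrary example parameter values (assumptions, not from the thread).
a = torch.tensor(1.5, requires_grad=True)
b = torch.tensor(0.3, requires_grad=True)
c = torch.tensor(2.0, requires_grad=True)

# Forward pass only: f(x) = a * exp(-(x - b)**2 / c), summed to a scalar loss.
f = a * torch.exp(-(x - b) ** 2 / c)
f.sum().backward()  # autograd fills a.grad, b.grad, c.grad

# Hand-derived partial derivatives, summed over the 20 samples.
with torch.no_grad():
    g = torch.exp(-(x - b) ** 2 / c)
    grad_a = g.sum()
    grad_b = (a * g * 2 * (x - b) / c).sum()
    grad_c = (a * g * (x - b) ** 2 / c ** 2).sum()

print(torch.allclose(a.grad, grad_a, atol=1e-6),
      torch.allclose(b.grad, grad_b, atol=1e-6),
      torch.allclose(c.grad, grad_c, atol=1e-6))  # True True True
```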

Thank you so much for your kind help! I would still like to ask a few questions:
(1) Should we add super(Gaussian, self).__init__() after def __init__(self):?
(2) After we add this new module to our network and start to train it, do I only need to use

optimizer = optim.SGD(net.parameters(), lr=0.01)
optimizer.zero_grad()
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()

to update all the parameters?

  1. My bad, I typed the example quickly just to give you a sense of how the code should look, so I forgot about the super call. Nice catch!
  2. Yes, that should do it. Since a, b and c are nn.Parameter objects that were assigned directly to a module, they should appear in .parameters() of the main container.
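To see why the super call matters, here is a small sketch (the class names BrokenGaussian and FixedGaussian are hypothetical, chosen for illustration). Assigning an nn.Parameter before nn.Module.__init__() has run fails, because the module's internal parameter registry does not exist yet; with the call in place, a, b and c are registered and show up in .parameters():

```python
import torch
import torch.nn as nn

class BrokenGaussian(nn.Module):
    def __init__(self):
        # super().__init__() deliberately omitted, as in the first example
        self.a = nn.Parameter(torch.zeros(1))

class FixedGaussian(nn.Module):
    def __init__(self):
        super().__init__()  # sets up nn.Module's parameter registry first
        self.a = nn.Parameter(torch.zeros(1))
        self.b = nn.Parameter(torch.zeros(1))
        self.c = nn.Parameter(torch.ones(1))

try:
    BrokenGaussian()
    raised = False
except AttributeError as err:
    raised = True
    print("AttributeError:", err)

names = [name for name, _ in FixedGaussian().named_parameters()]
print(names)  # ['a', 'b', 'c']
```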

Thank you so much for your help!

Now that we have defined the Gaussian layer, I want to know how to use it in a network with several layers. For example,

import torch.nn as nn
import torch.nn.functional as F

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc1 = nn.Linear(100, 20)
        self.gauss = Gaussian()

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.gauss(x)
        return x

net = MyNet()

Is the above code correct? Or do we have to set up other things in order to use the newly defined layer in another network?

The above is correct, provided forward returns x at the end.
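Putting the nested layer together with the training loop from earlier in the thread, here is a self-contained sketch in current PyTorch. The random inputs, the zero target and the choice of nn.MSELoss are assumptions for illustration only. It shows that the Gaussian's parameters are registered under the gauss. prefix, so net.parameters() hands them to the optimizer and one step updates them:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Gaussian(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.ones(1))
        self.b = nn.Parameter(torch.zeros(1))
        self.c = nn.Parameter(torch.ones(1))  # start away from 0: c is a divisor

    def forward(self, x):
        # ** is exponentiation; the 1-element parameters broadcast over x
        return self.a * torch.exp(-(x - self.b) ** 2 / self.c)

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 20)
        self.gauss = Gaussian()

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.gauss(x)

torch.manual_seed(0)
net = MyNet()
names = [name for name, _ in net.named_parameters()]
print(names)  # ['fc1.weight', 'fc1.bias', 'gauss.a', 'gauss.b', 'gauss.c']

optimizer = optim.SGD(net.parameters(), lr=0.01)
criterion = nn.MSELoss()
inputs, target = torch.randn(8, 100), torch.zeros(8, 20)  # illustrative data

a_before = net.gauss.a.detach().clone()
optimizer.zero_grad()
output = net(inputs)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(output.shape, not torch.equal(a_before, net.gauss.a))  # torch.Size([8, 20]) True
```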

Can I define my own layer with a forward/backward function, or will I have to define the layer as a Function so that I can define the backward function too?

You have no need to define the backward function yourself as long as everything you use to calculate the loss is built from PyTorch operations, because autograd records them and differentiates them for you. You should definitely define a forward function, just as shown by @jdhao.
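If you do want to supply the backward pass yourself (for example for speed or numerical stability), the route is a torch.autograd.Function with static forward and backward methods. The following is a sketch, not a required pattern: GaussianFunction is a hypothetical name, it assumes the form a * exp(-(x - b)**2 / c), and it checks the hand-written gradients against finite differences with torch.autograd.gradcheck (which expects double-precision inputs):

```python
import torch

class GaussianFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, a, b, c):
        ctx.save_for_backward(x, a, b, c)  # inputs needed to compute gradients
        return a * torch.exp(-(x - b) ** 2 / c)

    @staticmethod
    def backward(ctx, grad_out):
        x, a, b, c = ctx.saved_tensors
        d = x - b
        g = torch.exp(-d ** 2 / c)  # recomputed instead of saved
        grad_x = grad_out * (-a * g * 2 * d / c)
        # a, b, c have shape (1,), so sum the broadcast gradients back down
        grad_a = (grad_out * g).sum().reshape(a.shape)
        grad_b = (grad_out * a * g * 2 * d / c).sum().reshape(b.shape)
        grad_c = (grad_out * a * g * d ** 2 / c ** 2).sum().reshape(c.shape)
        return grad_x, grad_a, grad_b, grad_c

x = torch.randn(10, dtype=torch.double, requires_grad=True)
a = torch.tensor([1.5], dtype=torch.double, requires_grad=True)
b = torch.tensor([0.2], dtype=torch.double, requires_grad=True)
c = torch.tensor([2.0], dtype=torch.double, requires_grad=True)

# gradcheck compares backward() with numerical gradients; returns True on success
ok = torch.autograd.gradcheck(GaussianFunction.apply, (x, a, b, c))
print(ok)  # True
```

For a layer like this one, the forward-only nn.Module is simpler and gives the same gradients; the Function route mainly pays off when autograd's own backward is unavailable or unsuitable.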

By the way, if we implement a layer that has SciPy or NumPy operations inside it, like here, can it run and be accelerated on the GPU? Or does the layer just run on the CPU, because NumPy and SciPy cannot run on the GPU?

You can add .cuda() to both the module and the input to see whether the given example runs without any error.

@jdhao Yes, it works; I have tried it before. However, I want all operations to run on the GPU. In that example I think (please correct me if I am wrong) it still runs, but the only operations that run on the GPU are tensor operations like here, or some basic PyTorch functions. So when we use .cuda(), there will be a big communication cost between the GPU and the CPU. For example, when I implement convolution manually with Python loops over PyTorch tensors, it takes far longer than the ready-made PyTorch convolution. I just want to know what I am missing about extending PyTorch: do all operations, even SciPy ones used inside a PyTorch layer, really run on the GPU? Or does the SciPy part still run on the CPU while the tensors live on the GPU, which incurs a big communication cost between CPU and GPU?

Most of the time we do not need to extend PyTorch using NumPy. As long as you use the built-in methods of Variable, you only need to write the forward method, and the backward gradient computation is handled by autograd. So composing built-in Variable methods to achieve what you want saves a lot of time.

@herleeyandi, the portions of your code which use scipy will not be GPU-accelerated. Only operations on CUDA Torch tensors and Variables will be GPU-accelerated.
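A small sketch of what this means in practice (the device is picked at runtime, so it also runs on a CPU-only machine). Pure tensor operations stay on the tensor's device and inside the autograd graph; a NumPy/SciPy round trip has to copy the data to the CPU and detach it, so that part runs on the CPU and no gradient flows back through it:

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(5, device=device, requires_grad=True)

# Pure tensor op: runs on x's device and is tracked by autograd.
y_torch = torch.exp(x)

# NumPy round trip: detach + copy to CPU, compute there, copy back.
# The result is no longer connected to x in the autograd graph.
y_np = torch.from_numpy(np.exp(x.detach().cpu().numpy())).to(device)

print(y_torch.device, y_torch.requires_grad)  # x's device, True
print(y_np.requires_grad)                     # False
print(torch.allclose(y_torch.detach(), y_np))
```

If such a NumPy step is unavoidable, the gradient through it has to be supplied by hand in a custom torch.autograd.Function; otherwise it is usually better to find equivalent torch operations.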

@colesbury I see. So what if I want to create some functions that are GPU-accelerated? Should I use CFFI with code written in CUDA C++? Can you give me more hints?

Yes, you can do that. Writing new CUDA kernels usually requires a lot of effort. If you can express your layer in terms of existing Tensor operations, then that’s usually a better way to get started. If you can’t do that, then you might have to write new kernels.
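As an illustration of "expressing your layer in terms of existing Tensor operations", here is a sketch comparing a Python-loop 1-D convolution (like the manual convolution mentioned above) with the built-in torch.nn.functional.conv1d. The function names are hypothetical. Both give the same values and both are differentiable with only a forward written, but the loop issues one small operation per output element, which is typically much slower, especially on the GPU:

```python
import torch
import torch.nn.functional as F

def conv1d_loop(x, w):
    # Naive "valid" 1-D cross-correlation: one dot product per output position.
    k = w.numel()
    return torch.stack([(x[i:i + k] * w).sum() for i in range(x.numel() - k + 1)])

def conv1d_builtin(x, w):
    # F.conv1d takes (batch, channels, length) input and (out_ch, in_ch, k) weight,
    # and, like the loop above, computes cross-correlation (no kernel flip).
    return F.conv1d(x.view(1, 1, -1), w.view(1, 1, -1)).view(-1)

torch.manual_seed(0)
x = torch.randn(50)
w = torch.randn(5, requires_grad=True)

loop_out = conv1d_loop(x, w)
builtin_out = conv1d_builtin(x, w)
print(torch.allclose(loop_out, builtin_out, atol=1e-6))  # True

# Autograd differentiates both versions with respect to w.
g_loop, = torch.autograd.grad(loop_out.sum(), w)
g_builtin, = torch.autograd.grad(builtin_out.sum(), w)
print(torch.allclose(g_loop, g_builtin, atol=1e-5))  # True
```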

What if the forward function is not differentiable? How should I update the parameters and return the gradient w.r.t. the inputs?
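This question is not answered in the thread. One common workaround (an assumption here, not the only option) is a custom torch.autograd.Function whose backward returns a surrogate gradient. The sketch below rounds in the forward pass and uses a straight-through estimator, i.e. it treats d round(x)/dx as 1, so the gradient with respect to the input is simply the incoming gradient. RoundSTE is a hypothetical name:

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Rounds in the forward pass; passes the gradient straight through."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Surrogate gradient: pretend round() is the identity function.
        return grad_output

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
print(y.tolist())       # [0.0, 2.0, -1.0]
print(x.grad.tolist())  # [1.0, 1.0, 1.0]
```

Plain torch.round would give zero gradients here, so nothing upstream would learn; with the surrogate, parameters before this layer receive gradients and are updated by the optimizer as usual.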

I tried the piece of code posted by apaszke, but I got the following error.

RuntimeError: bitwise_xor(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

So I am confused how @big_tree and @jdhao managed to get this code to work.

I read the Extending PyTorch documentation, but its "Extending torch.nn" (adding a module) section uses the LinearFunction defined earlier on that page, which explicitly defines both forward and backward functions. That doesn't seem to be the case in apaszke's sample code above. So is there a way to make his code work without explicitly coding the gradient myself?

I figured out that the problem comes from the power calculation: in Python, ^ is bitwise XOR, not exponentiation. I changed ^2 to **2 and it works now.
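A short sketch of what went wrong and of the fixed expression. ^ is XOR for Python ints and for integer tensors, it is rejected on float tensors, and because ^ binds more loosely than /, the original (x - b)^2 / c is actually parsed as (x - b) ^ (2 / c). The exact error message depends on the PyTorch version, so the sketch only checks that an exception is raised:

```python
import torch

# In Python, ^ is bitwise XOR and ** is exponentiation.
print(3 ^ 2, 3 ** 2)  # 1 9

# Same on integer tensors: ^ is elementwise bitwise XOR.
t = torch.tensor([3, 4])
print((t ^ 2).tolist(), (t ** 2).tolist())  # [1, 6] [9, 16]

x = torch.randn(5)
a = torch.ones(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
c = torch.ones(1, requires_grad=True)

try:
    _ = (x - b) ^ 2 / c  # parsed as (x - b) ^ (2 / c): XOR on float tensors
    xor_failed = False
except Exception:
    xor_failed = True

# Fixed expression; -(x - b) ** 2 means -((x - b) ** 2).
out = a * torch.exp(-(x - b) ** 2 / c)
out.sum().backward()
print(xor_failed, a.grad is not None)  # True True
```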