Implementing a backward function in nn.Module

Hello,

I am trying to write a custom function that computes the gradient in the backward pass of my activation function. Specifically, I am trying to replace the straight-through estimator from the BinaryNet paper (linked below) with my own function.
https://arxiv.org/abs/1602.02830

In BinaryNet, they implement a BinaryTanh neuron subclassed from nn.Module, whose forward() calls hard tanh and then another neuron called Binarize. The Binarize neuron is subclassed from nn.Function; it computes sign in its forward() and simply passes the incoming gradient through unchanged in its backward(). Thus, in the backward pass, the derivative of hard tanh is used, since the derivative of sign is 0 almost everywhere. PyTorch's automatic differentiation takes care of this.
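For concreteness, a minimal sketch of that structure (my own reconstruction following the paper's description, not the authors' actual code) would be:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Binarize(torch.autograd.Function):
    # sign() in the forward pass, straight-through (identity) gradient in backward.
    @staticmethod
    def forward(ctx, input):
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        return grad_output

class BinaryTanh(nn.Module):
    def forward(self, input):
        # hardtanh clips to [-1, 1]; autograd supplies its derivative, so the
        # overall backward pass behaves like the derivative of hard tanh.
        return Binarize.apply(F.hardtanh(input))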

I am trying to use my own function in the backward pass instead of the derivative of hard tanh. How would I implement this? Does nn.Module have its own backward function? Would I have to call my own function in the forward pass as well, so that automatic differentiation takes care of its derivative in the backward pass? If I did this, would it slow down the forward pass?

Thanks

You’ll have to write your own autograd.Function. An example is given here: http://pytorch.org/docs/master/notes/extending.html
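Following those notes, a minimal sketch could look like the following; the gradient returned in backward() is just a placeholder that you would swap for your own function:

import torch

class BinarySurrogate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Placeholder surrogate: the derivative of tanh instead of hard tanh's gate.
        # Replace this line with whatever gradient approximation you want.
        return grad_output * (1 - torch.tanh(input) ** 2)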

@smth Thanks for the response. I did do this. I found that if I implement the derivative of hard tanh manually in the backward pass, subclassed from Function as you said, then I don’t get the same result as having automatic differentiation do it. In other words, the network doesn’t learn. Is there any particular reason for this? I implement the derivative in the backward() of my Function subclass.
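Roughly, by “implementing the derivative of hard tanh manually” I mean something along these lines (a simplified sketch, not my exact code):

import torch

class BinaryTanhManual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Derivative of hard tanh: 1 inside [-1, 1], 0 outside.
        grad_input = grad_output.clone()
        grad_input[input.abs() > 1] = 0
        return grad_input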

Hmmm, I’m not sure. You can double-check whether you implemented your gradient correctly using the gradcheck helper function.
http://pytorch.org/docs/master/notes/extending.html?highlight=gradcheck
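For example (MyFunction stands in for your own autograd.Function; keep in mind that a deliberately approximate gradient such as a straight-through estimator will not pass a numerical check like this, so gradcheck is mainly useful for verifying hand-written gradients of smooth functions):

import torch
from torch.autograd import gradcheck

# gradcheck compares your analytical backward() against numerical finite
# differences; it needs double-precision inputs with requires_grad=True.
# MyFunction is a placeholder for your own autograd.Function subclass.
x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
print(gradcheck(MyFunction.apply, (x,), eps=1e-6, atol=1e-4))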

@smth Is calling hard tanh in the forward() of an nn.Module equivalent to computing the derivative of hard tanh in the backward() of an nn.Function? I am asking with respect to backpropagation and automatic differentiation, i.e. the backward pass. I believe there must be some difference, but I don’t really understand what it is.

Also, I notice that the gradients passed into this backward() function (subclassed from nn.Function) are extremely small. Is that supposed to be the case?

I am using the hard tanh as an activation function.

Maybe this is useful for you?
I implemented a custom backward function for a ternary network (I have also read the BinaryNet paper you cited), and I ran into similar problems while doing so.
Perhaps breaking your function down into two pieces will solve your problems, too? At least in my case, it works fine; see the sketch below.
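Roughly, the split looks like this (a sketch only; the class names and the 0.5 threshold are placeholders, not my actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TernaryQuantize(torch.autograd.Function):
    # Piece 1: only the non-differentiable quantization step lives here.
    @staticmethod
    def forward(ctx, input):
        output = input.sign()
        output[input.abs() < 0.5] = 0  # placeholder threshold
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Identity (straight-through) gradient for the quantization step only.
        return grad_output

class TernaryTanh(nn.Module):
    def forward(self, input):
        # Piece 2: the differentiable clipping, handled by ordinary autograd.
        return TernaryQuantize.apply(F.hardtanh(input))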

I am implementing random feedback alignment with a sigmoidal activation function for the hidden layer (the network is a two-layer model, i.e. one hidden layer).

I had to override the backward method.

Is the code below correct?

@staticmethod
def backward(context, grad_output):
    input, weight, weight_rfa, bias = context.saved_variables
    grad_input = grad_weight = grad_weight_rfa = grad_bias = None

    if context.needs_input_grad[0]:
        # Feedback alignment: route the error back through the fixed random
        # feedback matrix instead of the transpose of the forward weights.
        grad_input = grad_output.mm(weight_rfa)

    if context.needs_input_grad[1]:
        grad_weight = grad_output.t().mm(input)

    if bias is not None and context.needs_input_grad[3]:
        grad_bias = grad_output.sum(0).squeeze(0)

    # weight_rfa is fixed (not learned), so its gradient stays None.
    return grad_input, grad_weight, grad_weight_rfa, grad_bias
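For context, a matching forward() would look roughly like this (modeled on the LinearFunction example in the extending notes; the random feedback matrix weight_rfa is passed in as an extra, fixed argument), so the four gradients returned by backward() line up with forward's arguments:

@staticmethod
def forward(context, input, weight, weight_rfa, bias=None):
    # Save everything backward() needs; weight_rfa is the fixed random feedback matrix.
    context.save_for_backward(input, weight, weight_rfa, bias)
    output = input.mm(weight.t())
    if bias is not None:
        output += bias.unsqueeze(0).expand_as(output)
    return output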