I am trying to write a custom function to compute the gradient in the backward pass of my activation function. Specifically, I am trying to replace the straight-through estimator from the BinaryNet paper (https://arxiv.org/abs/1602.02830) with my own function.
In BinaryNet, they implement a BinaryTanh neuron subclassed from nn.Module, whose forward() calls hard tanh followed by another neuron called Binarize. The Binarize neuron is subclassed from nn.Function: its forward() computes sign, and its backward() simply passes the incoming gradient through unchanged. Thus, in the backward pass, the effective derivative is that of hard tanh, since the derivative of sign is 0 almost everywhere. PyTorch automatic differentiation takes care of this.
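For reference, the setup described above can be sketched with the modern torch.autograd.Function API (the class names here are my own, not the paper's code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Binarize(torch.autograd.Function):
    """Sign in the forward pass, straight-through in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the incoming gradient unchanged.
        return grad_output

class BinaryTanh(nn.Module):
    def forward(self, x):
        # hardtanh clips to [-1, 1]; autograd supplies its derivative,
        # which zeroes the gradient wherever |x| >= 1.
        return Binarize.apply(F.hardtanh(x))
```

Because Binarize's backward is the identity, the gradient the network actually sees is just hard tanh's derivative, exactly as described above.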
I would like to use my own function in the backward pass instead of the derivative of hard tanh. How would I implement this? Does nn.Module have its own backward function? Would I have to call my own function in the forward pass as well, so that automatic differentiation takes care of its derivative in the backward pass? If I did this, would it slow down the forward pass?
@smth Thanks for the response. I did try this. I found that if I implement the derivative of hard tanh manually in the backward pass, subclassed from Function as you suggested, I don't get the same result as letting automatic differentiation do it; in other words, the network doesn't learn. Is there a particular reason for this? I implement the derivative in the backward() of my nn.Function subclass.
@smth Is calling hard tanh in the forward() of nn.Module equivalent to applying the derivative of hard tanh manually in the backward() of nn.Function? I mean with respect to backpropagation and automatic differentiation, i.e. the backward pass. I believe there must be some difference, but I don't really understand what it is.
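In my understanding the two should be equivalent, provided the manual backward masks the gradient with respect to the same tensor that hard tanh would have seen. A common pitfall is masking the binarized output instead of the saved pre-activation input. A minimal sketch of the all-in-one version (the name BinaryTanhFn is mine, using the modern torch.autograd.Function API):

```python
import torch

class BinaryTanhFn(torch.autograd.Function):
    """sign in forward; hard tanh's derivative applied manually in backward."""

    @staticmethod
    def forward(ctx, x):
        # Save the *pre-activation* input; the mask in backward must be
        # computed from it, not from the binarized output.
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Derivative of hard tanh: 1 on (-1, 1), 0 outside.
        mask = ((x > -1) & (x < 1)).to(grad_output.dtype)
        return grad_output * mask
```

If this version produces different gradients than the hardtanh-plus-straight-through composition, the usual culprits are masking the wrong tensor or returning the wrong number of gradients from backward().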
Also, I notice that the gradients arriving at this backward() function subclassed from nn.Function are extremely small. Is that supposed to be the case?
I am using the hard tanh as an activation function.
Maybe this is useful for you?
I implemented a custom backward function for a ternary network (I also read the BinaryNet paper you cited), and I faced similar problems along the way.
Perhaps breaking your function down into two pieces will solve your problem, too? At least in my case, it works fine.
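For what it's worth, here is roughly what I mean by two pieces, with a made-up ternary quantizer (the name and the 0.5 threshold are arbitrary assumptions, not from any paper): the non-differentiable step lives in a small Function with a straight-through backward, and the differentiable part (here plain tanh) is left to autograd.

```python
import torch

class TernarizeSTE(torch.autograd.Function):
    """Piece 1: non-differentiable quantizer with a straight-through backward."""

    @staticmethod
    def forward(ctx, x):
        # Map values to {-1, 0, +1} using an assumed threshold of 0.5.
        return torch.where(x > 0.5, torch.ones_like(x),
               torch.where(x < -0.5, -torch.ones_like(x),
                           torch.zeros_like(x)))

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pretend the quantizer was the identity.
        return grad_output

def ternary_tanh(x):
    # Piece 2: a differentiable squashing whose gradient autograd handles.
    return TernarizeSTE.apply(torch.tanh(x))
```

The gradient that reaches the input is then exactly tanh's derivative, with no manual bookkeeping beyond the one-line straight-through backward.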