Custom binarization layer with straight-through estimator gives error

I made a custom binarization layer following the XNOR-Net implementation.

I wrote the following BinaryLayer as an equivalent of the BinActiveZ layer specified in their code:

class BinaryLayer(Function):

    def forward(self, input):
        return torch.sign(input)

    def backward(self, gradOutput):
        # print("Grad output is: ", gradOutput)
        gradinput = gradOutput.clamp_(-1, 1)
        # print("Grad input is:", gradinput)
        return gradinput

After adding it to my model and running it, I get the following error:
File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
RuntimeError: could not compute gradients for some functions

How do I solve this issue?


You should use torch.sign directly in the forward of your module, since its backward is already implemented:

class Net(nn.Module):
    def __init__(self):
        [...]

    def forward(self, input):
        x = self.layer1(input)
        # binary layer step:
        x = torch.sign(x)
        [...]
        return x

I believe the backward pass of torch.sign() simply returns all gradients as zeros, not the output expected from a straight-through estimator.
Here is the code and outputs to reproduce this:

import torch
import torch.nn as nn
from torch.autograd import Variable, Function
 
class BinaryLayer(nn.Module):
    def forward(self, input):
        return torch.sign(input)
 
input = torch.randn(4,4)
input = Variable(input, requires_grad=True)
 
model = BinaryLayer()
output = model(input)
loss = output.mean()
 
>>> loss.backward()
>>> input
Variable containing:
-1.4272  1.5698  2.6661  0.4438
 0.4978  0.8987  1.6969  0.2067
 0.3880 -2.1434 -1.1588 -0.5567
-1.2435 -0.1010  0.7215 -0.9209
[torch.FloatTensor of size 4x4]
 
>>> input.grad
Variable containing:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
[torch.FloatTensor of size 4x4]

Hence, I implemented this layer. It works when used in isolation. Here is a snippet to verify it:

class BinaryLayer(Function):
    def forward(self, input):
        return torch.sign(input)
 
    def backward(self, grad_output):
        return grad_output.clamp_(-1, 1)
 
input = torch.randn(4,4)
input = Variable(input, requires_grad=True)
 
model = BinaryLayer()
output = model(input)
loss = output.mean()
>>> loss.backward()
>>> input
Variable containing:
 0.1690 -0.0028  1.4472 -0.1484
-0.6580  0.9200  0.5465  0.3896
-2.1211 -1.4266 -0.9634 -1.3991
-1.4426  1.4950  0.9849  0.4504
[torch.FloatTensor of size 4x4]
 
>>> input.grad
Variable containing:
1.00000e-02 *
  6.2500  6.2500  6.2500  6.2500
  6.2500  6.2500  6.2500  6.2500
  6.2500  6.2500  6.2500  6.2500
  6.2500  6.2500  6.2500  6.2500
[torch.FloatTensor of size 4x4]

However, I am unable to integrate this into a model. How do I go about this?

Maybe you can do

z = x + x.sign().detach() - x.detach()

which is x.sign() in the forward pass and behaves like x in the backward pass?
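
For reference, here is a minimal, self-contained sketch of that detach trick, written against the same Variable-era API used elsewhere in this thread (the variable names are only illustrative):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4, 4), requires_grad=True)

# Forward: the two detached terms add sign(x) - x without contributing any gradient,
# so z evaluates to sign(x). Backward: only the plain x term carries a gradient,
# so the upstream gradient passes straight through unchanged.
z = x + x.sign().detach() - x.detach()

z.mean().backward()
# x.grad is now filled with 1/16, exactly as if z were x itself.

Note that this identity estimator does not zero the gradient outside [-1, 1]; the masked version discussed below does.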


But the thing is, in the backward pass I want my layer to return the gradient of x clamped between -1 and +1, not just the gradient of x itself :confused:


I believe the backward pass of torch.sign() simply returns all gradients as zeros, not the output expected from a straight-through estimator.

OK, I see. But first, the correct straight-through estimator for the derivative of sign is not clamping grad_output between -1 and 1. You want grad_output to be 0 where the input is smaller than -1 or bigger than 1. So, something like this:

def backward(self, grad_output):
    input, = self.saved_tensors   # forward must call self.save_for_backward(input)
    grad_output[input > 1] = 0
    grad_output[input < -1] = 0
    return grad_output

Then, after you define your function extending Function, you can call it in the forward of your module just like in the example above.
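
For completeness, here is a minimal sketch of how the pieces fit together, using the same old-style Function API as the rest of this thread; the class names BinarizeF and BinaryAct are placeholders chosen for illustration:

import torch
import torch.nn as nn
from torch.autograd import Variable, Function

class BinarizeF(Function):
    def forward(self, input):
        self.save_for_backward(input)      # keep the input so backward can mask the gradient
        return torch.sign(input)

    def backward(self, grad_output):
        input, = self.saved_tensors
        grad_input = grad_output.clone()   # avoid modifying grad_output in place
        grad_input[input > 1] = 0          # straight-through estimator: let the gradient
        grad_input[input < -1] = 0         # pass only where the input lies in [-1, 1]
        return grad_input

class BinaryAct(nn.Module):
    def forward(self, input):
        return BinarizeF()(input)          # old-style Functions are instantiated per call

input = Variable(torch.randn(4, 4), requires_grad=True)
loss = BinaryAct()(input).mean()
loss.backward()
# input.grad is 1/16 where |input| <= 1 and 0 elsewhere

The nn.Module wrapper is what lets the binarization step be dropped into an nn.Sequential or a larger model like any other activation.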


Really sorry about making a mistake in the straight-through estimator, I was probably half asleep :stuck_out_tongue: Thank you very much for the help; my layer is working now after calling the class inherited from Function inside my layer inherited from nn.Module :smiley:
