Forward gradient from one Variable to another in XNOR net

I’m trying to implement a kind of binary NN (think of XNOR-Net, for example), where the forward() pass uses the binarized weights (torch.sign(weights)) for inference, but the backward / optimization step updates the original full-precision self.weight (gradients are computed w.r.t. torch.sign(weights)).

import torch.nn as nn
import torch.nn.functional as F

class BinaryLinear(nn.Linear):
    def forward(self, x):
        return F.linear(x, self.weight.sign(), self.bias)

But self.weight.sign() creates a new Variable, and gradients do not flow back to the full-precision self.weight. The workaround is:

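class BinaryLinear(nn.Linear):
    def forward(self, x):
        # Keep a copy of the full-precision weights
        weight_clone = self.weight.data.clone()
        # Binarize in place on .data, which autograd does not track
        self.weight.data.sign_()
        x = F.linear(x, self.weight, self.bias)
        # Restore the full-precision weights
        self.weight.data = weight_clone
        return x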

That is too much overhead for one simple operation. How can I correctly implement the first case, i.e. pass gradients from a local self.weight.sign() to my self.weight during the backward pass / optimization step?

sign is a piecewise-constant function whose gradient is zero almost everywhere (it is undefined at zero). In theory, the workaround should not work.
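To make the zero-gradient point concrete, here is a minimal sketch of what autograd does with the built-in sign(). The tensor values are arbitrary and chosen only for illustration.

```python
import torch

# Arbitrary full-precision "weights" that require gradients.
w = torch.tensor([-1.5, -0.2, 0.7, 2.0], requires_grad=True)

# Differentiate through the built-in sign(); autograd treats its
# derivative as zero, so nothing useful reaches w.
w.sign().sum().backward()

sign_grad = w.grad.clone()
print(sign_grad)
```

So in the first BinaryLinear, self.weight does get a gradient, but it is all zeros, and the optimizer step leaves the weights unchanged.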

I think you need to implement a custom backward function that would calculate the gradient of the signed weights and then copy that gradient over to the full precision weights.

Another workaround is to create a custom sign() function that does not modify incoming gradients.

class SignPassGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Binarize in the forward pass
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient through unchanged
        return grad_output

class BinaryLinear(nn.Linear):
    def forward(self, x):
        return F.linear(x, SignPassGrad.apply(self.weight), self.bias)

However, it does not yield any speedup (probably because the memory copies in torch.clone are already fast).
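A quick way to check that the pass-through version behaves as intended is to compare self.weight.grad with the gradient expected for the binarized weights. This is only a sketch: the layer sizes, the batch size, and the sum() loss are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SignPassGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient through unchanged
        return grad_output

class BinaryLinear(nn.Linear):
    def forward(self, x):
        return F.linear(x, SignPassGrad.apply(self.weight), self.bias)

torch.manual_seed(0)
layer = BinaryLinear(4, 3)   # arbitrary sizes
x = torch.randn(5, 4)        # arbitrary batch of 5 inputs

out = layer(x)
out.sum().backward()

# With a sum() loss, d(loss)/d(out) is all ones, so the gradient w.r.t. the
# binarized weights is ones(3, 5) @ x. The pass-through backward should hand
# exactly that to the full-precision self.weight.
expected = torch.ones(5, 3).t() @ x
print(torch.allclose(layer.weight.grad, expected))
```

The forward output is computed from the binarized weights, while layer.weight itself stays in full precision and receives a non-zero gradient that an optimizer can apply.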
