I’m trying to implement a kind of binary NN (think of XNOR, for example), where the forward() pass uses the binarized weights (torch.sign(weights)) for inference, but the backward / optimization step updates the original full-precision self.weight (gradients are computed w.r.t. torch.sign(weights)).
```python
class BinaryLinear(nn.Linear):
    def forward(self, x):
        return F.linear(x, self.weight.sign(), self.bias)
```
But self.weight.sign() creates a new Variable, and since the derivative of sign() is zero almost everywhere, no useful gradient flows back to the full-precision self.weight. The workaround is:
```python
class BinaryLinear(nn.Linear):
    def forward(self, x):
        # Stash the full-precision weights, binarize them in place for the
        # forward computation, then restore the original values afterwards.
        weight_clone = self.weight.data.clone()
        self.weight.data.sign_()
        x = F.linear(x, self.weight, self.bias)
        self.weight.data = weight_clone
        return x
```
That is too much overhead for one simple operation. How can I implement the first version correctly, i.e. have the gradients from the local self.weight.sign() flow back to my self.weight during the backward pass / optimization step?
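
One direction I’ve been considering (not sure whether it’s the idiomatic way) is a straight-through estimator implemented as a custom torch.autograd.Function that simply passes the incoming gradient through to the full-precision weight. The names BinarizeSTE / BinaryLinear below are just my own sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BinarizeSTE(torch.autograd.Function):
    """sign() in the forward pass, identity (straight-through) in the backward pass."""

    @staticmethod
    def forward(ctx, weight):
        # Binarize the full-precision weight for inference.
        return weight.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient straight through to the full-precision weight.
        return grad_output


class BinaryLinear(nn.Linear):
    def forward(self, x):
        # self.weight stays full precision; only the value used in F.linear is binary.
        return F.linear(x, BinarizeSTE.apply(self.weight), self.bias)
```

Would that be the right approach, or is there a cheaper trick, e.g. building the binarized weight as self.weight + (self.weight.sign() - self.weight).detach(), so the detached term contributes nothing to the gradient?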