Binary weights for the forward pass and real weights for the backward pass

I want to create a binary linear layer whose weights are binarized during the forward pass; during the backward pass, however, I want to retain the real-valued weights, something like this one in Torch.

How do I use the binary weights only during the forward pass and the real-valued weights during the backward pass?

The right and clean solution would be to create a custom autograd Function for this:

See this page for guidance on writing custom autograd Functions: http://pytorch.org/docs/master/notes/extending.html
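
Here is a minimal sketch of what that could look like, assuming a straight-through estimator (the gradient flows to the real-valued weights unchanged). The names `BinarizeSTE` and `BinaryLinear` are just illustrative, not built-ins:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BinarizeSTE(torch.autograd.Function):
    """Sketch: binarize weights in forward, straight-through gradient in backward.
    (BinarizeSTE is an illustrative name, not part of PyTorch.)"""

    @staticmethod
    def forward(ctx, weight):
        # The forward pass sees only the binarized weights (+1 / -1).
        # Note: torch.sign maps exact zeros to 0; clamp first if that matters.
        return torch.sign(weight)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: route the incoming gradient
        # to the real-valued weights unchanged.
        return grad_output


class BinaryLinear(nn.Module):
    """Linear layer that stores real-valued weights but computes
    its forward pass with a binarized copy of them."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        # The optimizer keeps updating self.linear.weight (real-valued);
        # only the forward computation uses the binarized view.
        binary_weight = BinarizeSTE.apply(self.linear.weight)
        return F.linear(x, binary_weight, self.linear.bias)


# Quick check: gradients land on the real-valued weights.
layer = BinaryLinear(4, 2)
out = layer(torch.randn(3, 4))
out.sum().backward()
print(layer.linear.weight.grad)  # real-valued gradients, as desired
```

BinaryConnect-style implementations often also zero the gradient where `|weight| > 1` (a hard-tanh straight-through estimator), which you could do in `backward` by saving the weight with `ctx.save_for_backward` and masking `grad_output`.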
