How to override the backward pass of a function

I want to do something like :

from torch.autograd import Function
class BinaryLayer(Function):
    """Hard-threshold activation whose gradient is passed through unchanged.

    The forward pass maps every element strictly greater than 0.5 to 1.0 and
    every other element to 0.0.  The backward pass hands the incoming
    gradient back as-is.

    NOTE(review): both methods are written as instance methods (taking
    ``self``), exactly as originally posted.
    """

    def forward(self, input):
        """Return a float tensor of 0/1 flags marking where ``input > 0.5``."""
        above_threshold = input > 0.5
        return above_threshold.float()

    def backward(self, grad_output):
        """Return ``grad_output`` itself (same object, no copy)."""
        upstream_grad = grad_output
        return upstream_grad

Here the forward pass would be a binary activation and the backward pass would be linear (the incoming gradient is passed through unchanged).

I get the following error:
RuntimeError: could not compute gradients for some functions

when I use said function.

How can I get it to work?
Thanks in advance,

OK - I went through the documentation and tried a few things. I'm still not sure why this new version works, but the following works for me (for anyone who runs into the same issue):

import torch
from torch import nn
from torch.autograd import Function
from torch.optim import SGD

class BinaryActivation(Function):
    """Rounding activation whose backward pass forwards the gradient as-is.

    Call it through ``BinaryActivation.apply(x)``.  ``forward`` and
    ``backward`` are static methods that receive the autograd context as
    their first argument, as the autograd extension mechanism expects.
    """

    @staticmethod
    def forward(ctx, x):
        """Round every element of ``x`` to the nearest integer.

        Args:
            ctx: autograd context object (unused; nothing is saved).
            x: input tensor.

        Returns:
            A tensor with the same shape as ``x`` holding the rounded values.
        """
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of the incoming gradient.

        Rounding is treated as the identity for differentiation purposes, so
        the gradient with respect to ``x`` equals ``grad_output``.

        Args:
            ctx: autograd context object (unused).
            grad_output: gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            A clone of ``grad_output``.
        """
        return grad_output.clone()

class BinaryLayer(Function):
    """Rounding activation written with instance-method ``forward``/``backward``.

    The forward pass rounds each element to the nearest integer; the backward
    pass returns the incoming gradient object unchanged.

    NOTE(review): the methods take ``self`` rather than being static methods;
    this form is kept exactly as posted.
    """

    def forward(self, input):
        """Return ``input`` rounded element-wise to the nearest integer."""
        rounded = input.round()
        return rounded

    def backward(self, grad_output):
        """Return ``grad_output`` itself (same object, no copy)."""
        upstream_grad = grad_output
        return upstream_grad

class SkipRNN(nn.Module):
    """Recurrent layer whose state update is switched by a binarised gate.

    At every time step a gate value ``u_t`` is rounded to 0 or 1.  Where it is
    1 the state is replaced by ``hidden_layer(x_t)``; where it is 0 the
    previous state is carried over unchanged.

    Args:
        c_in: number of input features per time step.
        c_hidden: size of the recurrent state.
    """

    def __init__(self, c_in=10, c_hidden=10):
        super(SkipRNN, self).__init__()

        self.hidden_layer = nn.Linear(c_in, c_hidden)
        # Maps the state to a single value in (0, 1) used to update the gate.
        self.gate = nn.Sequential(nn.Linear(c_hidden, 1), nn.Sigmoid())
        self.num_hidden = c_hidden

    def forward(self, x):
        """Run the layer over a sequence.

        Args:
            x: tensor with ``x.shape = [batch, time_steps, features]``.

        Returns:
            Tensor of shape ``[batch, time_steps, num_hidden]`` holding the
            state after every time step.
        """
        bn = BinaryActivation.apply
        batch_size, num_steps = x.size(0), x.size(1)

        # Allocate state and output buffers from ``x`` so they share its
        # device and dtype instead of always landing on the default device.
        u_t = x.new_zeros((batch_size, 1))
        s_t = x.new_zeros((batch_size, self.num_hidden))
        out = x.new_zeros((batch_size, num_steps, self.num_hidden))

        for t in range(num_steps):
            # Binarised gate: 1 -> update the state, 0 -> keep the old state.
            u_t_bin = bn(u_t)
            s_t = u_t_bin * self.hidden_layer(x[:, t, :]) + (1 - u_t_bin) * s_t
            del_u_t = self.gate(s_t)
            # After an update the gate restarts from del_u_t; otherwise it is
            # incremented by del_u_t, with the increment capped at 1 - u_t.
            u_t = u_t_bin * del_u_t + (1 - u_t_bin) * (u_t + torch.min(del_u_t, 1 - u_t))
            out[:, t, :] = s_t

        return out

def basic_check():
    """Smoke-test the binary activation inside a tiny two-layer network.

    Fits random weights ``w1`` (5x10) and ``w2`` (10x5) to random targets for
    50 plain gradient-descent steps, with the binary activation applied to
    the hidden layer.  Returns nothing; it only checks that the forward and
    backward passes run.
    """
    learning_rate = .1
    x = torch.rand((8, 5)).float()
    y = torch.rand((8, 5)).float()
    # Create random Tensors for weights.
    w1 = torch.randn(5, 10, dtype=torch.float, requires_grad=True)
    w2 = torch.randn(10, 5, dtype=torch.float, requires_grad=True)

    # Custom autograd functions are invoked through ``.apply`` rather than by
    # instantiating the Function subclass.
    bn = BinaryActivation.apply

    for _ in range(50):
        # NOTE(review): the forward expression was truncated in the original
        # post; this form is reconstructed from the tensor shapes - confirm.
        y_pred = bn(x.mm(w1)).mm(w2)
        loss = (y_pred - y).pow(2).mean()

        # Populate w1.grad / w2.grad before they are read below.
        loss.backward()

        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after updating weights
            w1.grad.zero_()
            w2.grad.zero_()

def skip_rnn_check():
    """Smoke-test ``SkipRNN`` by fitting it to random targets with SGD.

    Builds a ``SkipRNN(10, 10)``, feeds it a random ``[8, 20, 10]`` batch and
    runs 50 optimisation steps on the mean squared error against random
    targets of the same shape.  Returns nothing; it only checks that the
    forward pass, backward pass and parameter update run.
    """
    learning_rate = .1
    x = torch.rand((8, 20, 10)).float()
    y = torch.rand((8, 20, 10)).float()

    model = SkipRNN(10, 10)

    optimizer = SGD(model.parameters(), lr=learning_rate)

    for _ in range(50):
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()
        y_pred = model(x)
        loss = (y_pred - y).pow(2).mean()
        # Back-propagate and apply the update; without these two calls the
        # optimizer and the loss would never be used.
        loss.backward()
        optimizer.step()

if __name__ == '__main__':
    # Placeholder statement: no check function is invoked here, so running
    # this file as a script only binds ``hi``.
    # NOTE(review): call basic_check() and/or skip_rnn_check() here to
    # actually exercise the code.
    hi = 5

If anyone has an explanation of why BinaryActivation works and BinaryLayer doesn't, I'd love to better understand this.


You can check the documentation on extending autograd.
The forward and backward functions must be static methods (decorated with `@staticmethod`), and the function must be called through `.apply`.

1 Like

OK, that makes sense. Thanks!