How to override a backward pass

I want to do something like:

from torch.autograd import Function
class BinaryLayer(Function):
    def forward(self, input):
        return (input > .5).float()

    def backward(self, grad_output):
        return grad_output

where the forward pass would be a binary activation and the backward pass would be linear (the gradient just passes straight through).

When I use this function, I get the following error:

RuntimeError: could not compute gradients for some functions

How can I get it to work?
Thanks in advance,
Dan

OK, I went through the documentation and tried a few things. I'm still not sure why this new version works, but the following works for me (for anyone who runs into the same issue):

import torch
from torch import nn
from torch.autograd import Function
from torch.optim import SGD


class BinaryActivation(Function):
    """New-style autograd Function: forward and backward are static methods."""

    @staticmethod
    def forward(ctx, x):
        # Save the input in case backward ever needs it (unused here).
        ctx.save_for_backward(x)
        # Binarize by rounding to 0/1.
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the gradient through unchanged.
        return grad_output.clone()


class BinaryLayer(Function):
    # Old-style autograd Function (non-static forward/backward),
    # kept for comparison -- this is the variant that fails.
    def forward(self, input):
        return input.round()

    def backward(self, grad_output):
        return grad_output


class SkipRNN(nn.Module):
    def __init__(self, c_in=10, c_hidden=10):
        super(SkipRNN, self).__init__()

        self.hidden_layer = nn.Linear(c_in, c_hidden)
        self.gate = nn.Sequential(*[nn.Linear(c_hidden, 1), nn.Sigmoid()])
        self.num_hidden = c_hidden

    def forward(self, x):
        '''x.shape = [batch, time_steps, features]'''
        bn = BinaryActivation.apply
        u_t = torch.zeros((x.size(0),1)).float()
        s_t = torch.zeros((x.size(0), self.num_hidden)).float()
        out = torch.zeros((x.size(0), x.size(1), self.num_hidden))
        for t in range(x.size(1)):
            u_t_bin = bn(u_t)
            s_t = u_t_bin * self.hidden_layer(x[:, t, :]) + (1 - u_t_bin) * s_t
            del_u_t = self.gate(s_t)
            u_t = u_t_bin * del_u_t + (1 - u_t_bin) * (u_t + torch.min(del_u_t, 1 - u_t))
            out[:, t, :] = s_t

        return out


def basic_check():
    learning_rate = .1
    x = torch.rand((8, 5)).float()
    y = torch.rand((8, 5)).float()
    # Create random Tensors for weights.
    w1 = torch.randn(5, 10, dtype=torch.float, requires_grad=True)
    w2 = torch.randn(10, 5, dtype=torch.float, requires_grad=True)
    for t in range(50):
        # bn = BinaryActivation.apply  # new-style Function, called via .apply
        bn = BinaryLayer()  # old-style Function instance

        y_pred = bn(x.mm(w1)).mm(w2)
        loss = (y_pred - y).pow(2).mean()
        loss.backward()

        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after updating weights
            w1.grad.zero_()
            w2.grad.zero_()


def skip_rnn_check():
    learning_rate = .1
    x = torch.rand((8, 20, 10)).float()
    y = torch.rand((8, 20, 10)).float()

    model = SkipRNN(10, 10)

    optimizer = SGD(model.parameters(), lr=.1)

    for t in range(50):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = (y_pred - y).pow(2).mean()
        loss.backward()
        optimizer.step()



if __name__ == '__main__':
    basic_check()
    skip_rnn_check()
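As a quick sanity check (a minimal sketch, assuming only the BinaryActivation class above), the new-style Function can be exercised on its own; the gradient passes straight through the rounding op:

import torch

x = torch.randn(4, requires_grad=True)
y = BinaryActivation.apply(x)  # called via .apply, not by instantiating the class
y.sum().backward()
print(x.grad)  # all ones: backward returns grad_output unchanged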


If anyone has an explanation of why BinaryActivation works and BinaryLayer doesn't, I'd love to better understand this.

Hi,

You can check the documentation on extending autograd.
The forward and backward functions must be static methods (@staticmethod).
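
For reference, a minimal sketch of that fix applied to the BinaryLayer from the question (the same pattern BinaryActivation already uses): make forward and backward static methods and call the Function through .apply:

import torch
from torch.autograd import Function

class BinaryLayer(Function):
    @staticmethod
    def forward(ctx, input):
        # Binary activation in the forward pass.
        return (input > .5).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Backward is linear: return the incoming gradient unchanged.
        return grad_output

out = BinaryLayer.apply(torch.randn(3, requires_grad=True))  # call via .apply, not BinaryLayer()(x)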


OK, that makes sense. Thanks!