Use existing gradient computations in autograd function

I would like to do some gradient hacking on a linear layer. The most elegant way I found to do this is to write a custom autograd Function, as follows:

from torch.autograd import Function

def custom_grad(s1=1, s2=1):
    class Custom(Function):

        @staticmethod
        def forward(ctx, inputs, weights):
            # Save the tensors needed to compute the gradients by hand in backward.
            ctx.save_for_backward(inputs, weights)

            outputs = inputs @ weights.t()
            return s1 * outputs

        @staticmethod
        def backward(ctx, grad_output):
            inputs, weights = ctx.saved_tensors
            grad_input = grad_output.clone()

            # Gradients of a plain linear layer; the input gradient is rescaled by s2.
            dx = grad_input @ weights
            dw = grad_input.unsqueeze(-1) @ inputs.unsqueeze(1)

            return s2 * dx, dw.sum(0)

    return Custom.apply
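
For reference, a minimal usage sketch of this factory, assuming custom_grad as defined above is in scope (shapes and scale factors are just illustrative):

import torch

x = torch.randn(8, 16, requires_grad=True)   # (batch, in_features)
w = torch.randn(32, 16, requires_grad=True)  # (out_features, in_features)

scaled_linear = custom_grad(s1=2.0, s2=0.5)  # returns Custom.apply
y = scaled_linear(x, w)                      # forward: 2.0 * (x @ w.t())
y.sum().backward()                           # backward: x.grad == 0.5 * (ones @ w)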

The problem with this approach, however, is that it is extremely slow (it takes about twice as long) due to the two explicit matrix multiplications in the backward pass. I tried using something like dx, dw = torch.autograd.grad(outputs, [inputs, weights], grad_input) to speed things up, but for some reason outputs behaves like a leaf variable there and therefore does not allow gradients to flow through it.

In the end, I only need to rescale some gradients, so it would be nice if I could just use the default autograd functionality, but it seems to be nearly impossible.

PS: I’m working with PyTorch 0.4.0 for now.

Hi,

You can use hooks for that.
Register a hook on the tensor whose gradient you want to scale, and return the scaled gradient.
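For example, a minimal sketch of this suggestion (the 0.5 scale and the tensor shapes are arbitrary):

import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)

# The hook receives the gradient w.r.t. x; whatever it returns replaces that gradient.
x.register_hook(lambda grad: 0.5 * grad)

(x @ w.t()).sum().backward()
# x.grad is now 0.5 * (ones @ w), i.e. the usual gradient scaled by the hook.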

I was hoping I could avoid hooks, since they obscure the actual idea behind the gradient hacking. Seems like I don’t have much of a choice here, then…

Well, hooks are supposed to be the right way to do gradient hacking :wink:

@albanD Any ideas on how I can keep the effect of the gradient hacking local to the forward pass of a module?

def forward(self, x):
    raw = x @ self.weight.t()
    y = self.s1 * raw
    # Divide out the s1 factor autograd applies and rescale the input gradient by s2.
    x.register_hook(lambda grad: grad * self.s2 / self.s1)
    return y

How can I prevent other gradients w.r.t. the tensor passed in as x from being affected by this hook? I assume something like

h = x.register_hook(lambda grad: grad * self.s2 / self.s1)
h.remove()

would effectively be a no-op.

Hi,

You can add x = x.clone() at the beginning of your forward function. The hook is then registered on the clone, so it only sees the gradient flowing back through this layer rather than every gradient accumulated into the original tensor.

This seems to be working, thanks! The only issue is when the inputs do not require gradients (e.g. at the input layer), since register_hook then raises an error. However, this can easily be resolved with something like:

def forward(self, x):
    if x.requires_grad:
        # Clone so the hook only affects the gradient flowing through this layer.
        x = x.clone()
        x.register_hook(lambda grad: grad * self.s2 / self.s1)

    raw = x @ self.weight.t()
    return self.s1 * raw
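
As a quick sanity check of why the clone keeps the hack local (ScaledLinear is a hypothetical module that just wraps the forward above; dimensions and scale factors are arbitrary): a gradient path that bypasses the layer is left untouched, because the hook lives on the clone.

import torch
import torch.nn as nn

class ScaledLinear(nn.Module):
    """Hypothetical wrapper around the forward shown above."""
    def __init__(self, in_features, out_features, s1=2.0, s2=0.5):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.s1, self.s2 = s1, s2

    def forward(self, x):
        if x.requires_grad:
            x = x.clone()
            x.register_hook(lambda grad: grad * self.s2 / self.s1)
        raw = x @ self.weight.t()
        return self.s1 * raw

x = torch.randn(4, 3, requires_grad=True)
layer = ScaledLinear(3, 5)

# One path through the hacked layer, one path that bypasses it entirely.
loss = layer(x).sum() + x.sum()
loss.backward()
# x.grad == s2 * (ones @ layer.weight) + ones: only the layer's own
# contribution is rescaled; the bypass path is unaffected.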