I would like to do some gradient hacking on a linear layer. The most elegant way I found is to write a custom autograd `Function`:
import torch
from torch.autograd import Function

def custom_grad(s1=1, s2=1):
    class Custom(Function):
        @staticmethod
        def forward(ctx, inputs, weights):
            ctx.save_for_backward(inputs, weights)
            outputs = inputs @ weights.t()
            return s1 * outputs

        @staticmethod
        def backward(ctx, grad_output):
            inputs, weights = ctx.saved_tensors
            grad_input = grad_output.clone()
            dx = grad_input @ weights                            # gradient w.r.t. inputs
            dw = grad_input.unsqueeze(-1) @ inputs.unsqueeze(1)  # per-sample outer products
            return s2 * dx, dw.sum(0)                            # rescale only the input gradient

    return Custom.apply
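For reference, here is a self-contained sanity check of the function above (the shapes, the factor 0.5, and the variable names are illustrative, not from my actual model):

```python
import torch
from torch.autograd import Function

def custom_grad(s1=1, s2=1):
    class Custom(Function):
        @staticmethod
        def forward(ctx, inputs, weights):
            ctx.save_for_backward(inputs, weights)
            return s1 * (inputs @ weights.t())

        @staticmethod
        def backward(ctx, grad_output):
            inputs, weights = ctx.saved_tensors
            dx = grad_output @ weights                            # gradient w.r.t. inputs
            dw = grad_output.unsqueeze(-1) @ inputs.unsqueeze(1)  # per-sample outer products
            return s2 * dx, dw.sum(0)
    return Custom.apply

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)

# forward is identical to x @ w.t(); the input gradient is halved
out = custom_grad(s1=1, s2=0.5)(x, w)
out.sum().backward()
```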
The problem with this approach, however, is that it is extremely slow (roughly twice as long as the built-in linear layer) because of the two matrix multiplications in `backward`. To speed things up, I tried something like

    dx, dw = torch.autograd.grad(outputs, [inputs, weights], grad_input)

but for some reason `outputs` is a leaf variable, so gradients cannot flow through it.
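As an illustration of why this fails (not from my original code; `Probe` and `tracked` are made-up names): operations inside a custom `Function.forward` run with gradient tracking disabled, so any tensor created there is detached from the graph.

```python
import torch
from torch.autograd import Function

tracked = []  # records whether autograd tracked the op inside forward()

class Probe(Function):
    @staticmethod
    def forward(ctx, x):
        y = x * 2
        # grad tracking is disabled here, so y has no grad_fn and
        # torch.autograd.grad(y, x) would raise an error
        tracked.append(y.requires_grad)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.ones(3, requires_grad=True)
out = Probe.apply(x)
```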
In the end, all I need is to rescale some gradients, so it would be nice if I could just use the default autograd functionality, but that seems to be nearly impossible.
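One trick that does stay within default autograd (a sketch, not something from the post; `scale_grad` and the factor 2.0 are made up): mixing a tensor with its detached copy leaves the forward value unchanged while scaling the gradient that flows back through it.

```python
import torch
import torch.nn.functional as F

def scale_grad(x, s):
    # forward value is exactly x; the backward gradient is multiplied by s
    return s * x + (1 - s) * x.detach()

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)

# default autograd computes the linear layer's backward;
# only the gradient w.r.t. x is rescaled, w.grad is untouched
out = F.linear(scale_grad(x, 2.0), w)
out.sum().backward()
```

Because the scaled tensor has the same value as `x`, the weight gradient is identical to the unscaled case; only `x.grad` is multiplied by the factor.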
PS: I’m working with PyTorch 0.4.0 for now.