How to pass additional information to backward()?

Hi, I’m new to PyTorch. I tried to pass some very simple extra information from layer to layer through backward(), in addition to output_grad. However, it doesn’t work even for the following extremely simple code. Can someone help me fix this? Thanks a lot!

```python
import torch
import torch.autograd as autograd
import torch.nn as nn

class MyFun(torch.autograd.Function):
    def forward(self, inp):
        return inp

    def backward(self, grad_out, P):
        grad_input = grad_out.clone()
        print('Custom backward called!')
        return grad_input, P-1

class MyMod(nn.Module):
    def forward(self, x):
        return MyFun()(x)

mod1 = MyMod()

y = autograd.Variable(torch.randn(1), requires_grad=True)
z = mod1(y)
P = autograd.Variable(torch.ones((1,1)))
```

```
TypeError: backward() missing 1 required positional argument: 'P'
```

  1. Your backward won’t be called because you return the original input ‘inp’. (This changed in 0.3; until then, return inp.clone().)
  2. Your backward takes too many arguments. It should take only grad_out, because forward returns a single value.
  3. Your backward returns too many values. It should return only a single value, because forward takes a single input.
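For reference, here is a minimal sketch of the same identity function with all three fixes applied. It assumes the newer static-method Function API (ctx plus Function.apply, PyTorch 0.4 and later); the snippet in the question uses the older instance-method style, which I believe current releases no longer accept.

```python
import torch


class MyFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        # Fix 1: return a new tensor instead of the input object itself.
        return inp.clone()

    @staticmethod
    def backward(ctx, grad_out):
        # Fix 2: forward has one output, so backward receives one gradient.
        print('Custom backward called!')
        # Fix 3: forward has one input, so backward returns one gradient.
        return grad_out.clone()


class MyMod(torch.nn.Module):
    def forward(self, x):
        # New-style Functions are invoked through .apply, not instantiated.
        return MyFun.apply(x)


y = torch.randn(1, requires_grad=True)
z = MyMod()(y)
z.backward()   # prints 'Custom backward called!'
print(y.grad)  # tensor([1.]), since MyFun is the identity
```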

Thanks a lot! If I just want to backprop some P that plays no role in forward(), do you have any suggestions on where I should add it? Is it better to just define another function in the nn.Module? Thanks again for your help.
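As far as I know, autograd only carries gradients between Functions, so there is no built-in channel for handing an arbitrary P from one layer’s backward to the next. A common workaround is to pass P into forward as an extra input, stash it on ctx, read it back in backward, and return None as its gradient. A minimal sketch; the name MyFunWithP and the choice to scale the gradient by P are purely illustrative:

```python
import torch


class MyFunWithP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, P):
        # P does not affect the output; it is only saved for backward.
        ctx.save_for_backward(P)
        return inp.clone()

    @staticmethod
    def backward(ctx, grad_out):
        (P,) = ctx.saved_tensors
        # Illustrative use of P: scale the incoming gradient by it.
        grad_input = grad_out * P
        # forward took two inputs, so return two values;
        # None means "no gradient" for P.
        return grad_input, None


y = torch.randn(1, requires_grad=True)
P = torch.full((1,), 3.0)
z = MyFunWithP.apply(y, P)
z.backward()
print(y.grad)  # tensor([3.])
```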

Any updates on whether this is possible? The context is the following: I want to implement a custom backward function for an invertible layer. Instead of saving the activations during the forward pass, the backward function takes the original output y as an additional input argument, recomputes the input x from it, and uses x to compute the gradients. I therefore need my custom backward function to accept not only the gradient but also the original output (in @mathkobe’s example, this is the additional argument P).
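One approach that may avoid modifying autograd’s internals: ctx.save_for_backward accepts a Function’s outputs as well as its inputs, so forward can save only y and backward can recompute x from it. A minimal sketch, assuming a toy invertible map y = sinh(x) (so x = asinh(y) and dy/dx = cosh(x)) standing in for the real layer; InvSinh is a hypothetical name:

```python
import torch


class InvSinh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.sinh(x)
        # Save only the output; the input x is not kept for backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # Recompute the input by inverting the layer...
        x = torch.asinh(y)
        # ...and use it for the local gradient dy/dx = cosh(x).
        return grad_out * torch.cosh(x)


x = torch.randn(5, dtype=torch.float64, requires_grad=True)
InvSinh.apply(x).sum().backward()
custom_grad = x.grad.clone()

# Compare against PyTorch's built-in gradient for sinh.
x.grad = None
torch.sinh(x).sum().backward()
print(torch.allclose(custom_grad, x.grad))  # True
```

Note that y itself is still kept (it is usually needed as the next layer’s input anyway); reversible-network implementations typically go further and free intermediate outputs, which this sketch does not attempt.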
I added the additional argument to the backward functions in torch/ and torch/autograd/, but after that the calls step into the C code and I can’t follow them with the Python debugger.
Any help or pointers on how to extend this would be great! Thanks a lot!