How to pass an intermediate value from the forward function into the class instantiation?

My problem is best illustrated with an example.
I wrapped an autograd.Function subclass inside the uniform factory function and used it in the forward method of the target_fn module, which works fine, as shown below.

def uniform(k):
  class qfn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
      ...
      return out

    @staticmethod
    def backward(ctx, grad_output):
      ...
      return grad_input
  
  return qfn.apply
  
class target_fn(nn.Module):
  def __init__(self, bit):
    super(target_fn, self).__init__()
    ...
    self.uni = uniform(k=bit)  

  def forward(self, x):
#     self.max = max(x)
    output = self.uni(x)
    return output  

However, I want the self.max computed in target_fn.forward to be used by uniform for some preprocessing. I tried something like the snippet below, but it raised an error:

def uniform(k, max_value):
  class qfn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
      ...
      out = out / max_value
      return out

    @staticmethod
    def backward(ctx, grad_output):
      ...
      return grad_input
    
  return qfn.apply
  
class target_fn(nn.Module):
  def __init__(self, bit):
    super(target_fn, self).__init__()
    ...
    self.uni = uniform(k=bit, max_value = self.max)   # error: target_fn instantiation does not have self.max

  def forward(self, x):
    self.max = x.max()
    output = self.uni(x)
    return output

The funny thing is that I can access target_fn.max in the debugger, but the line self.uni = uniform(k=bit, max_value=self.max) cannot, since __init__ runs before forward has ever set it.
How can I fix this? Thank you.

Hi,

Why not just pass self.max as a second argument to your qfn custom Function?
If it does not require gradients, you can simply return None for it in the backward pass.

Hi Alban,

Thank you for your reply. Sorry, I don’t quite get what you mean. As you can see in the second snippet, I already pass self.max as the second argument of the uniform function. Could you elaborate on your suggestion?

I meant the qfn Function itself, so that the forward becomes def forward(ctx, input, max_value).

Hi,
Actually, self.max is determined in target_fn.forward. I think my question is better rephrased like this:
How can I access self.max in target_fn.__init__ to get past this error? I know register_forward_hook could help with intercepting activations in forward; I am just curious, if I need to pass self.max as an argument to uniform() in target_fn.__init__, what should I do?
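
For reference, one way to keep constructing uniform in target_fn.__init__ while deferring the actual value is to bind a small callable (a closure over the module) instead of a concrete number. This is only a sketch under assumptions: get_max is a made-up argument name, and the quantization body is replaced by a plain division so the snippet runs.

import torch
import torch.nn as nn

def uniform(k, get_max):
  # k would drive the real quantization; it is unused in this sketch
  class qfn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
      max_value = get_max()             # resolved at call time, not at construction time
      ctx.max_value = max_value
      return input / max_value          # stand-in for the real preprocessing

    @staticmethod
    def backward(ctx, grad_output):
      return grad_output / ctx.max_value

  return qfn.apply

class target_fn(nn.Module):
  def __init__(self, bit):
    super(target_fn, self).__init__()
    self.max = None
    # bind a closure that reads self.max whenever qfn actually runs
    self.uni = uniform(k=bit, get_max=lambda: self.max)

  def forward(self, x):
    self.max = x.max().detach()         # set before the custom Function is applied
    return self.uni(x)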

I don’t understand why you want max to be saved on the Module, or why you cannot just pass it along:

def uniform(k):
  class qfn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, max_value):
      ...
      out = out / max_value
      return out

    @staticmethod
    def backward(ctx, grad_output):
      ...
      # EDIT: need to return None for non-differentiable inputs
      return grad_input, None
  
  return qfn.apply

class target_fn(nn.Module):
  def __init__(self, bit):
    super(target_fn, self).__init__()
    ...
    self.uni = uniform(k=bit)

  def forward(self, x):
    max_value = x.max()
    output = self.uni(x, max_value)
    return output

I don’t know if passing it into the forward function is a good idea. I tried your solution, but it raised RuntimeError: function qfnBackward returned an incorrect number of gradients (expected 2, got 1). I think the backward might also be considering max_value for backprop, so I tried to set max_value.requires_grad = False, but the error still remained.

def uniform(k):
  class qfn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, max_value):
      ...
      out = out / max_value
      return out

    @staticmethod
    def backward(ctx, grad_output):
      ...
      return grad_input   # only one gradient returned for the two forward inputs
  
  return qfn.apply

class target_fn(nn.Module):
  def __init__(self, bit):
    super(target_fn, self).__init__()
    ...
    self.uni = uniform(k=bit)

  def forward(self, x):
    max_value = x.max()
    max_value.requires_grad = False
    output = self.uni(x, max_value)
    return output

The reason I want max to be saved on the module is to construct the uniform properly. You see, the qfn.forward inside uniform contains some logic based on max.
Thanks again, albanD.

Oh sorry, I forgot to update the backward in my code sample above :confused: Done now.
For any input that is not a Tensor requiring grad (either not a Tensor at all, or a Tensor that does not require gradients), the backward should return None in its place.
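
As a minimal, self-contained illustration of that rule (not the thread’s qfn; the ScaleByConstant name and the division are made up for this sketch), a Function that takes one Tensor and one plain Python number returns exactly one value per forward input from backward, with None in the number’s slot:

import torch

class ScaleByConstant(torch.autograd.Function):

  @staticmethod
  def forward(ctx, input, max_value):
    ctx.max_value = max_value            # a plain number, not a Tensor
    return input / max_value

  @staticmethod
  def backward(ctx, grad_output):
    # two forward inputs besides ctx, so two return values: None for max_value
    return grad_output / ctx.max_value, None

x = torch.randn(4, dtype=torch.double, requires_grad=True)
ScaleByConstant.apply(x, 3.0).sum().backward()
print(x.grad)                            # every entry is 1/3

# gradcheck accepts the (grad_input, None) pattern as well
torch.autograd.gradcheck(lambda t: ScaleByConstant.apply(t, 3.0), (x,))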

Hi, I changed it to return grad_input, None but the error still occurred:
RuntimeError: function qfnBackward returned an incorrect number of gradients (expected 2, got 1)
Could you reproduce this error on your side? If the forward has two arguments besides the context ctx, does the backward really work with return grad_input, None? Thanks again.

Yes, that does work on my end. You should double-check your code to make sure you don’t have an old version of the custom Function lying around.
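
For completeness, a hedged end-to-end sketch of the pattern above that runs as-is, with the k-bit quantization replaced by a plain division; the only point is the plumbing, i.e. forward(ctx, input, max_value) together with backward returning (grad_input, None):

import torch
import torch.nn as nn

def uniform(k):
  # k would parameterize the real quantization; it is unused in this sketch
  class qfn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, max_value):
      ctx.save_for_backward(max_value)
      return input / max_value           # stand-in for the real k-bit quantization

    @staticmethod
    def backward(ctx, grad_output):
      max_value, = ctx.saved_tensors
      return grad_output / max_value, None   # one entry per forward input

  return qfn.apply

class target_fn(nn.Module):
  def __init__(self, bit):
    super(target_fn, self).__init__()
    self.uni = uniform(k=bit)

  def forward(self, x):
    max_value = x.max().detach()         # recomputed on every forward pass
    return self.uni(x, max_value)

m = target_fn(bit=4)
x = torch.randn(8, requires_grad=True)
m(x).sum().backward()                    # no "incorrect number of gradients" error
print(x.grad.shape)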