My problem is best illustrated with an example.
I wrapped an autograd.Function subclass inside a uniform factory function and used it in the forward of a target_fn module, which works fine. However, I now want the self.max computed in target_fn.forward to be used by uniform for some preprocessing. I tried something like the following, but an error was raised:
def uniform(k, max_value):
    class qfn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ...
            out = out / max_value
            return out

        @staticmethod
        def backward(ctx, grad_output):
            ...
            return grad_input
    return qfn.apply

class target_fn(nn.Module):
    def __init__(self, bit):
        super(target_fn, self).__init__()
        ...
        # error: self.max does not exist yet when __init__ runs
        self.uni = uniform(k=bit, max_value=self.max)

    def forward(self, x):
        self.max = torch.max(x)
        output = self.uni(x)
        return output
The funny thing is that I can actually access target_fn.max in the debugger, but the line self.uni = uniform(k=bit, max_value=self.max) cannot.
How can I fix this? Thank you.
Why not just pass self.max as a second argument to your qfn custom Function?
If it does not require gradients, you can simply return None for it in the backward pass.
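A minimal self-contained sketch of that pattern (the ScaleFn name and the plain division are illustrative stand-ins for the elided qfn logic):

```python
import torch

class ScaleFn(torch.autograd.Function):
    """Divides the input by a runtime max value (illustrative only)."""

    @staticmethod
    def forward(ctx, input, max_value):
        ctx.max_value = max_value          # stash the scalar for backward
        return input / max_value

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument (excluding ctx):
        # a gradient for `input`, and None for `max_value`.
        return grad_output / ctx.max_value, None

x = torch.randn(4, requires_grad=True)
y = ScaleFn.apply(x, 2.0)
y.sum().backward()
print(x.grad)  # each entry is 1 / 2.0 = 0.5
```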
Thank you for your reply. Sorry, I don't quite get what you mean. As you can see in the snippet above, I have already passed self.max as the second argument of the uniform function. Could you elaborate on your suggestion?
Hi,
Actually, self.max is only determined in target_fn.forward. I think my question is better rephrased like this:
How can we access self.max in target_fn.__init__ to get past this error? I know register_forward_hook can help intercept activations in forward; I am just asking, if we do need to pass self.max as an argument to uniform() in target_fn.__init__, what should I do?
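For context, the register_forward_hook approach mentioned above looks like this (a minimal sketch; the layer, hook name, and captured statistic are illustrative):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
captured = {}

def save_max(module, inputs, output):
    # Record the max of the incoming activation on each forward pass.
    captured["max"] = inputs[0].max().item()

handle = layer.register_forward_hook(save_max)
layer(torch.tensor([[1.0, 5.0, 3.0]]))
print(captured["max"])  # 5.0
handle.remove()
```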
I don't know whether passing it into the forward function is a good idea. I tried your solution, but it raised RuntimeError: function qfnBackward returned an incorrect number of gradients (expected 2, got 1). I thought the backward might be expecting a gradient for max_value too, so I tried setting max_value.requires_grad = False, but the error remained.
def uniform(k):
    class qfn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, max_value):
            ...
            out = out / max_value
            return out

        @staticmethod
        def backward(ctx, grad_output):
            ...
            return grad_input  # only one gradient for two forward inputs
    return qfn.apply

class target_fn(nn.Module):
    def __init__(self, bit):
        super(target_fn, self).__init__()
        ...
        self.uni = uniform(k=bit)

    def forward(self, x):
        max_value = torch.max(x)
        max_value.requires_grad = False
        output = self.uni(x, max_value)
        return output
The reason I want max to be saved on the module is to construct uniform properly: qfn.forward contains some logic based on max.
Thanks again, albanD.
Oh sorry, I forgot to update the backward in my code sample above. Done now.
Any input that is not a Tensor that requires grad (either not a Tensor, or a Tensor that does not require gradients) should have the backward return None for it.
Hi, I changed it to return grad_input, None but the error still occurred: RuntimeError: function qfnBackward returned an incorrect number of gradients (expected 2, got 1).
Could you reproduce this error on your side? If the forward has two arguments excluding the context ctx, will the backward really work with return grad_input, None? Thanks again.
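For reference, here is a runnable sketch of the whole setup with the backward returning two values, which does not trigger the gradient-count error. The elided k-bit logic is replaced by plain scaling, and the backward uses a straight-through-style pass-through gradient, both of which are assumptions, not the original code:

```python
import torch
import torch.nn as nn

def uniform(k):
    class qfn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, max_value):
            # Placeholder for the elided k-bit logic: plain scaling only.
            return input / max_value

        @staticmethod
        def backward(ctx, grad_output):
            # Two forward args (besides ctx) => two returned values;
            # None for max_value, which needs no gradient. The pass-through
            # gradient for input is a straight-through-style assumption.
            return grad_output, None
    return qfn.apply

class target_fn(nn.Module):
    def __init__(self, bit):
        super(target_fn, self).__init__()
        self.uni = uniform(k=bit)

    def forward(self, x):
        max_value = x.detach().max()  # detach so no grad flows to the stat
        return self.uni(x, max_value)

m = target_fn(bit=4)
x = torch.randn(5, requires_grad=True)
m(x).sum().backward()  # no "incorrect number of gradients" error
print(x.grad.shape)
```

If the error persists with a backward like this, one common cause is a stale module still importing the old definition, so restarting the interpreter is worth trying.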