I modified a custom autograd function like this.
I thought this error was caused by the multiple return values from function A,
but I need min and scale to make the rest of the code run properly.
class A(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, qparams=None):
        …
        return output, min, scale

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output
        return grad_input, None
class B(nn.Module):
    def __init__(self):
        …

    def forward(self, input, num_bits, qparams=None):
        …
        output, min, scale = A.apply(input, qparams)
        return output
The backward method expects one gradient argument for each value returned from forward,
so you would have to add arguments for the gradients of min and scale,
as described in the backward section of this doc.
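A minimal sketch of that fix: since forward returns three values (output, min, scale), backward must accept three gradient arguments; since forward takes two inputs (input, qparams), backward must return two values. The bodies of forward and the straight-through backward are placeholders standing in for your real quantization code.

```python
import torch

class A(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, qparams=None):
        # Placeholder computation standing in for the real quantization.
        output = input.clone()
        mn = input.min()
        scale = input.abs().max()
        # Three outputs -> backward will receive three gradient arguments.
        return output, mn, scale

    @staticmethod
    def backward(ctx, grad_output, grad_min, grad_scale):
        # Straight-through estimator: pass the gradient of `output`
        # through unchanged; the gradients for min and scale are ignored.
        grad_input = grad_output
        # Two forward inputs (input, qparams) -> return two values.
        return grad_input, None

x = torch.randn(4, requires_grad=True)
out, mn, scale = A.apply(x)
out.sum().backward()
```

Note that by default `ctx.set_materialize_grads(True)` is in effect, so `grad_min` and `grad_scale` arrive as zero tensors even when `min` and `scale` are not used in the loss, and can simply be ignored.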