Backward() takes 2 positional arguments but 4 were given

I modified a custom autograd function as shown below.
I think this error is caused by function A returning multiple values from forward,
but I need min and scale for other code to run properly.
class A(torch.autograd.Function):
    def forward(ctx, input, qparams=None):

        return output, min, scale
    def backward(ctx, grad_output):
        grad_input = grad_output
        return grad_input, None

class B(nn.Module):
    def __init__(self):

    def forward(self, input, num_bits, qparams=None):

        output, min, scale = A().apply(input, qparams)
        return output

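backward receives one gradient argument for each value returned by forward, in addition to ctx. Since your forward returns output, min, and scale, backward is called with ctx plus three gradients (four arguments in total), but it only accepts two. You would have to add these arguments to the backward signature, as described in the backward section of this doc. backward should still return one gradient per forward input, which your return grad_input, None already does.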
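As a minimal sketch of that fix: the forward body was elided in the question, so the min/scale computation below is only a placeholder assumption, and the names FakeQuant, min_val, grad_min, and grad_scale are hypothetical. The parts that matter are the three gradient arguments in backward and the two returned gradients (one per forward input).

```python
import torch


class FakeQuant(torch.autograd.Function):
    """Hypothetical stand-in for A; the quantization math is assumed."""

    @staticmethod
    def forward(ctx, input, qparams=None):
        # Placeholder computation (assumption): 8-bit style min/scale.
        min_val = input.detach().min()
        scale = (input.detach().max() - min_val).clamp_min(1e-8) / 255.0
        output = torch.round((input - min_val) / scale) * scale + min_val
        # min_val and scale are auxiliary outputs; tell autograd not to
        # track gradients through them.
        ctx.mark_non_differentiable(min_val, scale)
        return output, min_val, scale

    @staticmethod
    def backward(ctx, grad_output, grad_min, grad_scale):
        # One gradient argument per forward output (3 here).
        # Return one gradient per forward input: input and qparams.
        grad_input = grad_output  # straight-through, as in the question
        return grad_input, None


x = torch.randn(4, 3, requires_grad=True)
out, min_val, scale = FakeQuant.apply(x, None)
out.sum().backward()
```

Note that apply is called on the class itself rather than on an instance. If min and scale never need gradients, marking them non-differentiable as above keeps them out of the graph, but backward still has to accept an argument for each of them.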