I want to implement my own quantized and clipped ReLU activation. Here is my implementation:
import torch
import torch.nn as nn

class _quantAct(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, clip_low=0., clip=6., bits=8, inplace=False):
        if inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()
        # clip to [clip_low, clip]
        output[output < clip_low] = clip_low
        output[output > clip] = clip
        # quantize to 2**bits - 1 uniform levels, then rescale back to [0, clip]
        output = output.div(clip).mul((2 ** bits) - 1).round().div((2 ** bits) - 1).mul(clip)
        # save a boolean mask of the saturated positions for the backward pass
        ctx.save_for_backward(output.eq(clip_low) | output.eq(clip))
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors is a tuple holding the single mask tensor saved in forward
        mask, = ctx.saved_tensors
        # straight-through estimator: pass the gradient, but zero it where the output saturated
        grad_input = grad_output.masked_fill(mask, 0)
        return grad_input, None, None, None, None


class quantReLU(nn.ReLU):
    def __init__(self, clip=6., bits=8, inplace=False):
        super(quantReLU, self).__init__()
        self.clip = clip
        self.bits = bits
        self.inplace = inplace

    def forward(self, inputs):
        # Function.apply is called on the class itself, not on an instance
        return _quantAct.apply(inputs, 0., self.clip, self.bits, self.inplace)
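In case it matters, this is roughly how I exercise the layer; the input shape and the scaling here are just a toy example:

# toy usage: one forward/backward pass through the quantized ReLU
act = quantReLU(clip=6., bits=8)
x = (torch.randn(4, 16) * 3).requires_grad_()  # leaf tensor so x.grad is populated
y = act(x)
y.sum().backward()
print(y.min().item(), y.max().item())  # outputs stay within [0, 6]
print(x.grad)                          # gradient is zeroed where the output saturated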
How many gradients do I have to return from the static backward method of my class that inherits from torch.autograd.Function? Why does it expect me to return 5 of them?
Appreciate your inputs, thanks!