Hi all, I have recently been using torch.eq in my network, and I found that its output has no grad_fn, which means the loss cannot be backpropagated through it. Therefore, I want to define a differentiable equality function myself. Here is my code:
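```python
import torch


class Eq(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, value):
        mask = torch.eq(input, value)
        # Indices where input == value, shape (num_matches, input.dim())
        idx = mask.nonzero()
        # Remember how to rebuild a gradient matching input's shape/dtype/device
        ctx._input_shape = input.shape
        ctx._input_dtype = input.dtype
        ctx._input_device = input.device
        ctx.save_for_backward(idx)
        return mask.float()

    @staticmethod
    def backward(ctx, grad_output):
        idx, = ctx.saved_tensors
        grad_input = torch.zeros(ctx._input_shape, device=ctx._input_device, dtype=ctx._input_dtype)
        # Pass grad_output through only where input == value (zero elsewhere).
        # idx.unbind(1) gives one index tensor per dimension, so this works for
        # any input rank instead of only 3-D tensors (idx[:,0], idx[:,1], idx[:,2]).
        positions = idx.unbind(1)
        grad_input[positions] = grad_output[positions]
        # No gradient with respect to value
        return grad_input, None
```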
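For comparison, a minimal sketch of the same pass-through idea that saves the boolean mask instead of the nonzero() indices, so it needs no shape/dtype/device bookkeeping and works for any input rank. `MaskedEq` is a hypothetical name, and the gradient it returns (grad_output where input == value, zero elsewhere) is a hand-chosen surrogate, not the mathematical derivative of torch.eq.

```python
import torch


class MaskedEq(torch.autograd.Function):
    # Hypothetical variant of Eq: same forward, same pass-through backward,
    # but the saved state is the boolean mask rather than nonzero() indices.

    @staticmethod
    def forward(ctx, input, value):
        mask = torch.eq(input, value)
        ctx.save_for_backward(mask)
        # 0/1 mask in the input's floating dtype
        return mask.to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        # Keep grad_output where input == value, zero it everywhere else
        grad_input = grad_output.masked_fill(~mask, 0.0)
        # value gets no gradient
        return grad_input, None


x = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True)
y = MaskedEq.apply(x, 2.0)
y.sum().backward()
print(y)       # 0/1 mask of where x == 2.0
print(x.grad)  # 1.0 at the same positions, 0.0 elsewhere
```

Saving the mask costs one byte per input element, while nonzero() indices cost input.dim() int64 values per match, so which is smaller depends on how many elements match. The mask version also avoids nonzero(), which forces a host-device synchronization on CUDA because its output size depends on the data.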
However, I am not sure whether it is right or wrong, and torch.autograd.gradcheck does not seem applicable to this situation. Can anyone help me solve this problem?
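One way to see why gradcheck cannot agree here: it compares the analytical backward against finite-difference estimates of forward. The sketch below (assuming float64 inputs and gradcheck's default eps of 1e-6) computes such an estimate for the 0/1 equality mask directly; `eq_forward` is an illustrative helper, not part of the original code.

```python
import torch


def eq_forward(x: torch.Tensor, value: float) -> torch.Tensor:
    # Same forward as the Eq function above: a 0/1 equality mask
    return torch.eq(x, value).to(x.dtype)


x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
eps = 1e-6  # torch.autograd.gradcheck's default perturbation size

# The op is elementwise, so its Jacobian is diagonal and one vectorized
# central difference gives every diagonal entry at once.
numerical = (eq_forward(x + eps, 2.0) - eq_forward(x - eps, 2.0)) / (2 * eps)
print(numerical)  # all zeros, including at x == 2.0
```

A pass-through backward instead reports a 1 on the diagonal entry where x == 2.0, so gradcheck flags a mismatch. That failure only says the custom gradient is not the finite-difference derivative of forward, which a hand-designed pass-through gradient is not meant to be.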
Thanks,