Differentiable equal

Hi all, I have been using torch.eq in my network recently, and I found that its output has no grad_fn, which means the loss cannot be backpropagated through it. Therefore, I want to define a differentiable equal function myself. Here is my code.

import torch

class Eq(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, value):
        mask = torch.eq(input, value)
        idx = mask.nonzero()  # (K, input.dim()) coordinates of the matching elements
        ctx.save_for_backward(idx)
        ctx._input_shape = input.shape
        ctx._input_dtype = input.dtype
        ctx._input_device = input.device
        return mask.float()

    @staticmethod
    def backward(ctx, grad_output):
        idx, = ctx.saved_tensors
        grad_input = torch.zeros(ctx._input_shape, device=ctx._input_device, dtype=ctx._input_dtype)
        # Copy the upstream gradient only at matching positions (works for any input rank)
        coords = tuple(idx.t())
        grad_input[coords] = grad_output[coords].to(ctx._input_dtype)
        return grad_input, None

However, I am not sure whether it is right or wrong, and torch.autograd.gradcheck does not seem applicable to this situation. Can anyone help me solve this problem?
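A self-contained sketch of both points, assuming a recent PyTorch (the tensor x and the value 3.0 are illustrative choices). It first confirms that torch.eq returns a bool tensor with no grad_fn, then defines an Eq function (with the @staticmethod decorators and ctx.save_for_backward call that torch.autograd.Function expects, and a boolean mask instead of index lists) and runs gradcheck on it. gradcheck compares the custom backward against central finite differences: perturbing a matching element by eps in either direction makes it unequal, so the numerical Jacobian is all zeros while the custom backward reports 1 there. The check therefore fails, which says less about gradcheck being inapplicable than about the custom backward not being the true derivative.

```python
import torch
from torch.autograd import gradcheck

# torch.eq returns a bool tensor, which autograd never tracks.
x = torch.arange(8, dtype=torch.double).reshape(2, 2, 2).requires_grad_()
mask = torch.eq(x, 3.0)
print(mask.dtype, mask.requires_grad, mask.grad_fn)  # torch.bool False None


class Eq(torch.autograd.Function):
    """Illustrative "pass-through at matches" equality, not the true derivative."""

    @staticmethod
    def forward(ctx, input, value):
        mask = torch.eq(input, value)
        ctx.save_for_backward(mask)
        return mask.to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        # Let the upstream gradient through only where input == value.
        return grad_output * mask.to(grad_output.dtype), None


# Finite differences see a zero derivative everywhere, the custom backward does not.
ok = gradcheck(lambda t: Eq.apply(t, 3.0), (x,), raise_exception=False)
print(ok)  # False

# What the custom backward actually produces:
y = Eq.apply(x, 3.0)
y.sum().backward()
print(x.grad.flatten())  # 1.0 only at the position where x == 3, 0.0 elsewhere
```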

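Well, after careful consideration, I concluded that the equal operation cannot backpropagate the loss: its output is piecewise constant (0 or 1), so its derivative is zero wherever input != value and undefined where input == value. Trying to define such a function equals banging my head against a brick wall.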
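If some gradient signal is still needed, one common workaround (a swapped-in technique, not something torch.eq provides) is to replace the hard comparison with a smooth surrogate. The sketch below uses a Gaussian kernel; the name soft_eq and its bandwidth parameter sigma are hypothetical choices, and the surrogate is only an approximation of equality, not its gradient.

```python
import torch


def soft_eq(x: torch.Tensor, value: float, sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical smooth stand-in for (x == value): a Gaussian bump that is
    exactly 1 at x == value and decays toward 0 as |x - value| grows."""
    return torch.exp(-((x - value) ** 2) / (2 * sigma ** 2))


x = torch.tensor([2.0, 2.5, 5.0], requires_grad=True)
y = soft_eq(x, 2.0, sigma=1.0)
y.sum().backward()

# Values: 1 at the exact match, smaller further away.
print(y.detach())
# Gradients: d/dx exp(-(x - v)^2 / 2) = -(x - v) * exp(-(x - v)^2 / 2)
print(x.grad)
```

As sigma shrinks, soft_eq approaches the hard indicator, but its gradient is concentrated in an ever narrower band around value, so in the limit the same problem comes back.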