Torch.round() gradient

One possible approach is to implement your own round function with a customized backward pass. Since torch.round is piecewise constant, its gradient is zero almost everywhere, so autograd gives you no useful signal; a custom autograd Function lets you pass the upstream gradient straight through instead (a straight-through estimator).

E.g.

import torch
from torch.autograd.function import InplaceFunction

class my_round_func(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        # Round in the forward pass; the input is not needed for backward.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        grad_input = grad_output.clone()
        return grad_input
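A minimal self-contained sketch of how this is used (the tensor values are illustrative): call the Function via .apply, and the rounding happens in the forward pass while gradients still flow back to the input.

```python
import torch
from torch.autograd.function import InplaceFunction

class my_round_func(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        # Forward: ordinary rounding.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: pass the gradient through unchanged.
        return grad_output.clone()

x = torch.tensor([0.4, 1.6], requires_grad=True)
y = my_round_func.apply(x)   # tensor([0., 2.])
y.sum().backward()
print(x.grad)                # tensor([1., 1.]) -- straight-through gradient
```

With plain torch.round, x.grad would be all zeros here; the custom backward is what keeps training signals alive through the rounding step.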