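One possible approach is to implement your own round function with a custom backward pass. torch.round has a zero gradient almost everywhere, so nothing useful flows back through it; a custom autograd function can round in the forward pass and pass the gradient through unchanged in the backward pass.
For example: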
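import torch
from torch.autograd import Function

# InplaceFunction is kept in PyTorch only for backward compatibility,
# so this version subclasses Function instead.
class my_round_func(Function):
    @staticmethod
    def forward(ctx, input):
        # Round in the forward pass; backward needs nothing saved on ctx.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient through unchanged.
        grad_input = grad_output.clone()
        return grad_input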
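
To use a function like this, call it through `.apply` rather than instantiating the class. The snippet below is a minimal sketch, assuming a recent PyTorch version; the class name `RoundSTE` and the sample values are made up for illustration. It compares the gradient with the one produced by plain `torch.round`.

```python
import torch
from torch.autograd import Function


class RoundSTE(Function):
    # Hypothetical name; same idea as the round function described above.
    @staticmethod
    def forward(ctx, input):
        # Forward pass: ordinary rounding.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: treat rounding as the identity.
        return grad_output.clone()


# Custom function: the gradient flows through the rounding step.
x = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()

# Plain torch.round for comparison: the gradient is zero.
plain = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
torch.round(plain).sum().backward()

print(y)           # rounded values
print(x.grad)      # gradient through the custom function
print(plain.grad)  # gradient through torch.round
```

With the custom function, `x.grad` is all ones because the upstream gradient is copied through, while `plain.grad` is all zeros. Whether an identity gradient suits your model is a design choice. It is a common approximation for rounding, but it is an approximation.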