How to select elements in a tensor greater than a threshold and keep the gradient?

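You can use torch.gt() (or torch.le() for less-than-or-equal). It returns a boolean tensor with the same shape as the input tensor, which you can then use as a mask, e.g. by indexing with it or passing it to torch.where(). The comparison itself is not differentiable, but the values you pick out with the mask still propagate gradients back to the input.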
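Here is a minimal sketch of how that could look. The tensor values, the `threshold` of 1.0, and the squared-sum loss are just made up for illustration; they are not from the question.

```python
import torch

# Example input (hypothetical values) that tracks gradients.
x = torch.tensor([0.5, 2.0, -1.0, 3.0], requires_grad=True)
threshold = 1.0  # hypothetical threshold, not from the original post

# Boolean mask with the same shape as x; the mask itself carries no gradient.
mask = torch.gt(x, threshold)

# Option 1: boolean indexing returns a 1-D tensor of the selected elements.
selected = x[mask]

# Option 2: keep the original shape and zero out everything else.
masked = torch.where(mask, x, torch.zeros_like(x))

# Toy loss on the selected elements, only to show that gradients flow.
loss = (selected ** 2).sum()
loss.backward()

# The gradient is 2 * x where the mask is True and 0 elsewhere.
print(x.grad)
```

Use `x[mask]` if you only need the selected values, and `torch.where()` if you need the result to line up with the original tensor shape. In both cases the elements that fail the comparison just get a zero gradient.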
