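I suspect that torch.where is not backpropagating gradients here.
When both branches are constants (such as 0 and 1), the output does not depend on the inputs in a differentiable way, so nothing flows back to them.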
If you really want to threshold, try the straight-through estimator trick as follows:
thresholded_inputs = torch.where(inputs < threshold, 0, 1)
inputs = (inputs + thresholded_inputs) - inputs.detach()
... calculate IoU loss ...
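
For reference, here is a minimal self-contained sketch of the trick that you can run to check that gradients actually reach the inputs. The helper names threshold_ste and soft_iou_loss, the 0.5 threshold, and the random data are my own assumptions for illustration, not taken from your code, so adapt them to your setup.

```python
import torch


def threshold_ste(inputs: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: hard threshold forward, identity gradient backward."""
    # Hard 0/1 mask. The comparison carries no gradient on its own.
    hard = (inputs >= threshold).to(inputs.dtype)
    # Forward value is (numerically) `hard`; the detached term contributes no
    # gradient, so the backward pass treats the output as if it were `inputs`.
    return inputs + (hard - inputs).detach()


def soft_iou_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical IoU loss used only to have something to backpropagate."""
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return 1.0 - (intersection + eps) / (union + eps)


torch.manual_seed(0)
inputs = torch.rand(4, 4, requires_grad=True)   # stand-in for model outputs
target = (torch.rand(4, 4) > 0.5).float()       # stand-in for ground truth

binary = threshold_ste(inputs, threshold=0.5)
loss = soft_iou_loss(binary, target)
loss.backward()

# A plain hard mask is cut off from the graph, while the STE output is not.
print((inputs >= 0.5).float().requires_grad, binary.requires_grad)
print(inputs.grad.abs().sum() > 0)
```

Keep in mind that the gradient you get this way is an approximation: the backward pass pretends the threshold is the identity function, which is usually good enough for training but is not the true derivative.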