mask = (pred_cc_volume != 0) & (~torch.isin(pred_cc_volume, tp))
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument test_elements in method wrapper_CUDA_isin_Tensor_Tensor)
Is there any way to tell from this error which tensor is on which device? Or is it worth the effort to open a PR/issue?
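One reading of the message: `test_elements` is the name of the second argument of `torch.isin` (here `tp`), and the error is raised while checking that argument. That suggests `pred_cc_volume` (the `elements` argument) is on cuda:0 and `tp` is on the CPU. A quick way to confirm is to print both devices and move `tp` onto `pred_cc_volume.device` before the call. The minimal sketch below does that. The helper names `isin_on_same_device` and `build_mask` and the sample tensors are hypothetical, made up for illustration.

```python
import torch


def isin_on_same_device(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: print both devices, then run torch.isin on one device."""
    # Shows which tensor is the odd one out (e.g. cuda:0 vs cpu).
    print(f"elements: {elements.device}, test_elements: {test_elements.device}")
    # .to() returns the tensor unchanged if it is already on the target device.
    return torch.isin(elements, test_elements.to(elements.device))


def build_mask(pred_cc_volume: torch.Tensor, tp: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: the same expression as above, with tp moved to pred_cc_volume's device.
    return (pred_cc_volume != 0) & (~isin_on_same_device(pred_cc_volume, tp))


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Illustrative data only.
    pred_cc_volume = torch.tensor([[0, 1, 2], [3, 2, 0]], device=device)
    tp = torch.tensor([2])  # deliberately left on the CPU
    print(build_mask(pred_cc_volume, tp))
```

Moving `test_elements` to the device of `elements`, rather than the other way round, keeps the large volume where it already is and copies only the (usually small) set of labels.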