Thresholding output of U-Net

Hi I’m working on a 3D binary semantic segmentation using the U-net. I want to extract the probability mask from the prediction an visualize them and and also apply a threshold to regulate the segmentation. For the current code threshold = 0.6 results in larger segmentation masks than threshold = 0.4 which is wrong. Does anyone have an idea of how to solve this?

### predict masks
threshold = 0.6
for sub in subject_list:
    data_root = os.path.join(work_dir, 'dataset')
    out_dir = os.path.join(work_dir, 'output')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    images = sorted(glob.glob(os.path.join(data_root, 'images', '%s.nii.gz' %sub)))
    data_dicts = [{'image': image_name}
                  for image_name in zip(images)]
    val_files = data_dicts

    val_transforms = Compose([
        LoadNiftid(keys=['image']),
        AddChanneld(keys=['image']),
        NormalizeIntensityd(keys=['image']),
        Rotate90d(keys=['image'], spatial_axes=(1,0)),
        ToTensord(keys=['image'])
        ])


    check_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    check_loader = DataLoader(check_ds, batch_size=2)
    check_data = monai.utils.misc.first(check_loader)



    device = torch.device('cpu')
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
    img_saver = NiftiSaver(output_dir=out_dir, output_postfix='img', output_ext='.nii.gz')
    mask_saver = NiftiSaver(output_dir=out_dir, output_postfix='mask', output_ext='.nii.gz')
    
    
    prob_saver = NiftiSaver(output_dir=out_dir, output_postfix='prob', output_ext='.nii.gz')


    model = monai.networks.nets.UNet(dimensions=3, in_channels=1, out_channels=2, channels=(16, 32, 64, 128, 256),
                                     strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH).to(device)
    model.load_state_dict(torch.load(seg_model))
    model.eval()
    with torch.no_grad():
        for i, val_data in enumerate(val_loader):


            val_inputs = val_data['image'].to(device)
            roi_size = (32, 64, 32)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)

            # Convert logits to probabilities using a sigmoid function
            val_probs = torch.sigmoid(val_outputs)
            
                            
            # Apply threshold to convert probabilities to binary predictions
            custom_threshold = threshold 

            # Apply custom threshold
            val_preds = (val_probs > custom_threshold).float()

            # Now val_preds contains the binary segmentation masks
            seg = torch.argmax(val_preds, dim=1).detach().cpu()[0]
            

            # saving images
            t2 = nib.load(os.path.join(work_dir, 'cut_dir', 'images',  '%s.nii.gz' %sub))
            img = (val_data['image'][0][0])
            img = torch.squeeze(img)
            img_saver.save(data=img)
            load = nib.load(os.path.join(out_dir, '%s/%s_img.nii.gz' %(i,i))).get_fdata()
            load = np.squeeze(load)

            img = np.flip(load, 2)
            img = np.flip(img, 0)
            img = np.flip(img, 2)
            img = np.swapaxes(img,  -1, 1)


            nifti1 = nib.Nifti1Image(img, t2.affine)
            nib.save(nifti1, os.path.join(out_dir, '%s_img.nii.gz' %sub))
            img = Image(os.path.join(out_dir, '%s_img.nii.gz' %sub))


            res_img = resample(img, t2.shape,  order=3)
            res_img = nib.Nifti1Image(res_img[0], t2.affine, t2.header)
            res_img.to_filename(os.path.join(out_dir, '%s_img_res.nii.gz' %sub))

            # saving probability masks 
            prob_saver.save(data=val_probs)
            load = nib.load(os.path.join(out_dir, '%s/%s_prob.nii.gz' %(i,i)))
            load = nib.load(os.path.join(out_dir, '%s/%s_prob.nii.gz' %(i,i))).get_fdata()
            load = np.squeeze(load)
            img = np.flip(load, 2)
            img = np.flip(img, 0)
            img = np.flip(img, 2)
            img = np.swapaxes(img,  -1, 1)



            nifti1 = nib.Nifti1Image(img, t2.affine)
            nib.save(nifti1, os.path.join(out_dir, '%s_prob.nii.gz' %sub))
            img = Image(os.path.join(out_dir, '%s_prob.nii.gz' %sub))
            '''res_img = resample(img, t2.shape,  order=0)
            res_img = nib.Nifti1Image(res_img[0], t2.affine, t2.header)
            res_img.to_filename(os.path.join(out_dir, '%s_prob_res.nii.gz' %sub))'''

            # saving segmentation masks 
            mask_saver.save(data=seg)
            load = nib.load(os.path.join(out_dir, '%s/%s_mask.nii.gz' %(i,i)))
            load = nib.load(os.path.join(out_dir, '%s/%s_mask.nii.gz' %(i,i))).get_fdata()
            load = np.squeeze(load)
            img = np.flip(load, 2)
            img = np.flip(img, 0)
            img = np.flip(img, 2)
            img = np.swapaxes(img,  -1, 1)



            nifti1 = nib.Nifti1Image(img, t2.affine)
            nib.save(nifti1, os.path.join(out_dir, '%s_seg.nii.gz' %sub))
            img = Image(os.path.join(out_dir, '%s_seg.nii.gz' %sub))
            res_img = resample(img, t2.shape,  order=0)
            res_img = nib.Nifti1Image(res_img[0], t2.affine, t2.header)
            res_img.to_filename(os.path.join(out_dir, '%s_seg_res.nii.gz' %sub))

You shouldn’t worry about segmentation, what you need to do is look at your hardware. Are you on a mac? Would recommend to start there