Hi I’m working on a 3D binary semantic segmentation using the U-net. I want to extract the probability mask from the prediction an visualize them and and also apply a threshold to regulate the segmentation. For the current code threshold = 0.6 results in larger segmentation masks than threshold = 0.4 which is wrong. Does anyone have an idea of how to solve this?
### predict masks
threshold = 0.6
for sub in subject_list:
data_root = os.path.join(work_dir, 'dataset')
out_dir = os.path.join(work_dir, 'output')
if not os.path.exists(out_dir):
os.makedirs(out_dir)
images = sorted(glob.glob(os.path.join(data_root, 'images', '%s.nii.gz' %sub)))
data_dicts = [{'image': image_name}
for image_name in zip(images)]
val_files = data_dicts
val_transforms = Compose([
LoadNiftid(keys=['image']),
AddChanneld(keys=['image']),
NormalizeIntensityd(keys=['image']),
Rotate90d(keys=['image'], spatial_axes=(1,0)),
ToTensord(keys=['image'])
])
check_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=2)
check_data = monai.utils.misc.first(check_loader)
device = torch.device('cpu')
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
img_saver = NiftiSaver(output_dir=out_dir, output_postfix='img', output_ext='.nii.gz')
mask_saver = NiftiSaver(output_dir=out_dir, output_postfix='mask', output_ext='.nii.gz')
prob_saver = NiftiSaver(output_dir=out_dir, output_postfix='prob', output_ext='.nii.gz')
model = monai.networks.nets.UNet(dimensions=3, in_channels=1, out_channels=2, channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH).to(device)
model.load_state_dict(torch.load(seg_model))
model.eval()
with torch.no_grad():
for i, val_data in enumerate(val_loader):
val_inputs = val_data['image'].to(device)
roi_size = (32, 64, 32)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
# Convert logits to probabilities using a sigmoid function
val_probs = torch.sigmoid(val_outputs)
# Apply threshold to convert probabilities to binary predictions
custom_threshold = threshold
# Apply custom threshold
val_preds = (val_probs > custom_threshold).float()
# Now val_preds contains the binary segmentation masks
seg = torch.argmax(val_preds, dim=1).detach().cpu()[0]
# saving images
t2 = nib.load(os.path.join(work_dir, 'cut_dir', 'images', '%s.nii.gz' %sub))
img = (val_data['image'][0][0])
img = torch.squeeze(img)
img_saver.save(data=img)
load = nib.load(os.path.join(out_dir, '%s/%s_img.nii.gz' %(i,i))).get_fdata()
load = np.squeeze(load)
img = np.flip(load, 2)
img = np.flip(img, 0)
img = np.flip(img, 2)
img = np.swapaxes(img, -1, 1)
nifti1 = nib.Nifti1Image(img, t2.affine)
nib.save(nifti1, os.path.join(out_dir, '%s_img.nii.gz' %sub))
img = Image(os.path.join(out_dir, '%s_img.nii.gz' %sub))
res_img = resample(img, t2.shape, order=3)
res_img = nib.Nifti1Image(res_img[0], t2.affine, t2.header)
res_img.to_filename(os.path.join(out_dir, '%s_img_res.nii.gz' %sub))
# saving probability masks
prob_saver.save(data=val_probs)
load = nib.load(os.path.join(out_dir, '%s/%s_prob.nii.gz' %(i,i)))
load = nib.load(os.path.join(out_dir, '%s/%s_prob.nii.gz' %(i,i))).get_fdata()
load = np.squeeze(load)
img = np.flip(load, 2)
img = np.flip(img, 0)
img = np.flip(img, 2)
img = np.swapaxes(img, -1, 1)
nifti1 = nib.Nifti1Image(img, t2.affine)
nib.save(nifti1, os.path.join(out_dir, '%s_prob.nii.gz' %sub))
img = Image(os.path.join(out_dir, '%s_prob.nii.gz' %sub))
'''res_img = resample(img, t2.shape, order=0)
res_img = nib.Nifti1Image(res_img[0], t2.affine, t2.header)
res_img.to_filename(os.path.join(out_dir, '%s_prob_res.nii.gz' %sub))'''
# saving segmentation masks
mask_saver.save(data=seg)
load = nib.load(os.path.join(out_dir, '%s/%s_mask.nii.gz' %(i,i)))
load = nib.load(os.path.join(out_dir, '%s/%s_mask.nii.gz' %(i,i))).get_fdata()
load = np.squeeze(load)
img = np.flip(load, 2)
img = np.flip(img, 0)
img = np.flip(img, 2)
img = np.swapaxes(img, -1, 1)
nifti1 = nib.Nifti1Image(img, t2.affine)
nib.save(nifti1, os.path.join(out_dir, '%s_seg.nii.gz' %sub))
img = Image(os.path.join(out_dir, '%s_seg.nii.gz' %sub))
res_img = resample(img, t2.shape, order=0)
res_img = nib.Nifti1Image(res_img[0], t2.affine, t2.header)
res_img.to_filename(os.path.join(out_dir, '%s_seg_res.nii.gz' %sub))