I am building a model to segment microscopy images. My lab produces a very high volume of images, so analysis time is an important factor for us. Because of that, I built a simple UNet, which segments well. However, inference is quite slow, so I was wondering if someone could double-check my code to see whether I am doing something wrong that slows it down significantly:
Here is the code for my inference function:
```python
import gc

import torch
from monai.inferers import SlidingWindowInferer

# normalize_percentile and empty_gpu_cache are helper functions defined elsewhere (not shown).


def predict(self, image):
    """
    Predicts the segmentation mask for the given image. It can handle 2D images or a stack of 2D images.

    Args:
        image (numpy.ndarray): The input image or image stack.

    Returns:
        numpy.ndarray: The segmentation mask or stack of segmentation masks.
    """
    image = normalize_percentile(image)
    # Check if the image is a stack
    if len(image.shape) == 3:
        # Store the original image size and number of slices
        original_size = image.shape[1:]
        num_slices = image.shape[0]
        img_tensor = torch.from_numpy(image).unsqueeze(1).to(self.device)  # Shape: [Batch, Channels, H, W]
        if self.half_precision:
            img_tensor = img_tensor.half()  # Convert input to half-precision
        inferer = SlidingWindowInferer(roi_size=self.patch_size, sw_batch_size=1, overlap=self.overlap_ratio)
        with torch.no_grad():
            output_mask = inferer(img_tensor, self.model)
        output_mask = output_mask.squeeze(0).cpu().numpy()
    else:
        # Single 2D image: add batch and channel dimensions -> [1, 1, H, W]
        img_tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).to(self.device)
        if self.half_precision:
            img_tensor = img_tensor.half()
        inferer = SlidingWindowInferer(roi_size=self.patch_size, sw_batch_size=350, overlap=self.overlap_ratio)
        with torch.no_grad():
            output_mask = inferer(img_tensor, self.model)
        output_mask = output_mask.squeeze(0).squeeze(0).cpu().numpy()
    # Free up tensors
    del img_tensor, image
    gc.collect()
    empty_gpu_cache(self.device)
    return output_mask
```
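One way to make timings like the ones below comparable is to measure the bare model forward pass and the full sliding-window call separately, with warm-up runs and an explicit GPU synchronization before each clock read (CUDA kernels are launched asynchronously, so unsynchronized wall-clock numbers can be misleading). A minimal, framework-agnostic sketch; `benchmark` is a hypothetical helper, not part of PyTorch or MONAI:

```python
import statistics
import time


def benchmark(fn, *args, warmup=2, repeats=5, sync=None):
    """Return the median wall-clock time of fn(*args) in seconds (hypothetical helper).

    `sync` is an optional zero-argument callable (e.g. torch.cuda.synchronize)
    called before each clock read so queued GPU work is included in the timing.
    """
    # Warm-up runs absorb one-off costs such as lazy initialization or autotuning.
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # wait for asynchronous work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

For example, assuming a CUDA device and calls placed inside a `torch.no_grad()` block, `benchmark(self.model, patch_batch, sync=torch.cuda.synchronize)` could be compared with `benchmark(inferer, img_tensor, self.model, sync=torch.cuda.synchronize)`, where `patch_batch` is a hypothetical batch of patches of size `self.patch_size`.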
At the moment, it takes ~4 s per image, which is ~3 min per stack (42 images). During training, where I don't need to stitch the patches back together, I get ~0.5 s per image, so maybe my bottleneck is the stitching. I tried writing my own stitching function, but it was slower. Then I tried combining my stitching with `torch.compile`, but for some reason it degraded the output quality and was not much faster… Therefore, I switched to `SlidingWindowInferer` from the MONAI library.
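Conceptually, the stitching that `SlidingWindowInferer` performs can be sketched in plain NumPy: tile the image with overlapping windows, run the model on batches of tiles, accumulate the predictions, and divide by how many windows covered each pixel. This is a simplified stand-in under stated assumptions (2D input, uniform averaging similar in spirit to MONAI's default constant blending, no padding), not MONAI's implementation; `sliding_window_predict`, `window_starts` and the `predict_patches` callable (standing in for the model) are hypothetical names.

```python
import numpy as np


def window_starts(length, patch, step):
    """Start offsets so that windows of size `patch` cover [0, length)."""
    if length <= patch:
        return [0]
    starts = list(range(0, length - patch + 1, step))
    # Make sure the last window reaches the image border.
    if starts[-1] != length - patch:
        starts.append(length - patch)
    return starts


def sliding_window_predict(image, predict_patches, patch=(64, 64), overlap=0.25, batch_size=16):
    """Tile a 2D image, predict batches of tiles, and average overlapping predictions.

    `predict_patches` is a hypothetical callable mapping [B, ph, pw] -> [B, ph, pw].
    """
    h, w = image.shape
    ph, pw = min(patch[0], h), min(patch[1], w)  # shrink the window for small images
    step_h = max(1, int(ph * (1 - overlap)))
    step_w = max(1, int(pw * (1 - overlap)))
    coords = [(y, x) for y in window_starts(h, ph, step_h) for x in window_starts(w, pw, step_w)]

    out = np.zeros((h, w), dtype=np.float32)    # summed predictions
    count = np.zeros((h, w), dtype=np.float32)  # number of windows covering each pixel
    for i in range(0, len(coords), batch_size):
        chunk = coords[i:i + batch_size]
        tiles = np.stack([image[y:y + ph, x:x + pw] for y, x in chunk])  # [B, ph, pw]
        preds = predict_patches(tiles)  # one "model call" per batch of tiles
        for (y, x), p in zip(chunk, preds):
            out[y:y + ph, x:x + pw] += p
            count[y:y + ph, x:x + pw] += 1.0
    return out / count
```

In this sketch each call to `predict_patches` processes up to `batch_size` tiles at once, which is the role `sw_batch_size` plays in `SlidingWindowInferer`; with a batch size of 1, every tile becomes a separate model call, while the accumulate-and-divide step itself is a few cheap array additions per tile.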
Please let me know if you have any advice. Every little step helps!