Speeding up image inference

I am building a model to segment microscopy images. My lab produces a very high volume of images, so analysis time is an important factor for us. I built a simple UNet that does a good job at segmenting, but inference is quite slow. I was wondering if someone could double-check my code and see whether I am doing something wrong that slows it down by a significant factor.

Here is the code for my inference function:

    def predict(self, image):
        """
        Predicts the segmentation mask for the given image. It can handle 2D images or a stack of 2D images.

        Args:
            image (numpy.ndarray): The input image or image stack.

        Returns:
            numpy.ndarray: The segmentation mask or stack of segmentation masks.
        """
        
        image = normalize_percentile(image)

        # Check if the image is a stack
        if len(image.shape) == 3:
            # Store the original image size and number of slices
            original_size = image.shape[1:]
            num_slices = image.shape[0]

            img_tensor = torch.from_numpy(image).unsqueeze(1).to(self.device)  # Shape: [Batch, Channels, H, W]
            if self.half_precision:
                img_tensor = img_tensor.half()  # Convert input to half-precision

            inferer = SlidingWindowInferer(roi_size=self.patch_size, sw_batch_size=1, overlap=self.overlap_ratio)

            with torch.no_grad():
                output_mask = inferer(img_tensor, self.model)

            output_mask = output_mask.squeeze(0).cpu().numpy() 

        else:
            img_tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).to(self.device)
            if self.half_precision:
                img_tensor = img_tensor.half()
            inferer = SlidingWindowInferer(roi_size=self.patch_size, sw_batch_size=350, overlap=self.overlap_ratio)

            with torch.no_grad():
                output_mask = inferer(img_tensor, self.model)

            output_mask = output_mask.squeeze(0).squeeze(0).cpu().numpy()

            # Free up tensors
            del img_tensor, image  
            gc.collect() 
            empty_gpu_cache(self.device)

        return output_mask

At the moment it takes ~4 s per image, which is ~3 min per stack (42 images). During training, where I don’t need to stitch the patches back together, I can get ~0.5 s per image, so maybe my bottleneck is the image stitching. I tried writing my own stitching function, but it was slower. Then I tried combining my stitching with torch.compile, but for some reason that degraded the quality of the output and was not much faster either… Therefore, I moved to SlidingWindowInferer from the MONAI library.
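
For context, here is a simplified sketch of the kind of comparison I have in mind for narrowing this down: timing a plain forward pass on a single patch against the full sliding-window call. The names (model, inferer, image, patch_size, device) are placeholders for my actual objects, and the torch.cuda.synchronize calls would only matter once I run this on the CUDA server:

    import time
    import numpy as np
    import torch

    def time_patch_vs_sliding_window(model, inferer, image, patch_size, device, n_runs=5):
        # Plain forward pass on a single patch (no stitching involved)
        patch = torch.from_numpy(
            image[: patch_size[0], : patch_size[1]].astype(np.float32)
        ).unsqueeze(0).unsqueeze(0).to(device)

        # Full image through the sliding-window inferer (includes stitching)
        full = torch.from_numpy(image.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)

        with torch.no_grad():
            model(patch)  # warm-up run
            if device.type == "cuda":
                torch.cuda.synchronize(device)

            start = time.perf_counter()
            for _ in range(n_runs):
                model(patch)
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            t_forward = (time.perf_counter() - start) / n_runs

            start = time.perf_counter()
            for _ in range(n_runs):
                inferer(full, model)
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            t_sliding = (time.perf_counter() - start) / n_runs

        print(f"forward per patch: {t_forward:.3f} s | sliding window per image: {t_sliding:.3f} s")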

Please let me know if you have any advice. Every little bit helps!

Did you already profile your code with e.g. Nsight Systems to narrow down the actual bottlenecks?
If so, were you able to isolate the code parts slowing down the execution?
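
If Nsight Systems is not an option right now, the native torch.profiler would also give you a first breakdown of where the time is spent. A minimal sketch (inferer, img_tensor, and model here stand for the corresponding objects in your predict method):

    import torch
    from torch.profiler import profile, ProfilerActivity

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with torch.no_grad():
            output_mask = inferer(img_tensor, model)

    # print the ops sorted by accumulated CUDA time
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))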

Hi @ptrblck, I did all the benchmarking with the time module; I’m not very familiar with Nsight Systems. The issue is that at the moment I am micro-testing the model on my local computer, an M2 MacBook, but inference will be done on a server with an RTX 3090, a TITAN X, or an Intel Xeon Platinum, so it is a bit difficult to extrapolate. Many of the variables are still open; I was just wondering if there were any blunders in my code that I could be missing.

Generally, it’s not a good idea to move the tensors back to the CPU as this will synchronize your code:

    output_mask = output_mask.squeeze(0).cpu().numpy()

I don’t know how output_mask is used and whether the numpy array is strictly needed, but if possible keep the data on the GPU for as long as possible.
Besides that, you might want to check our performance guide.
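
As a rough sketch of what I mean (assuming the downstream code can work with tensors; the thresholding is just an example and should be replaced by whatever post-processing your model output actually needs):

    with torch.no_grad():
        output = inferer(img_tensor, self.model)  # result stays on the GPU

    # keep post-processing (thresholding, argmax, resizing, ...) on the GPU
    mask = (output > 0.5).to(torch.uint8)

    # a single synchronizing device-to-host copy at the very end,
    # and only if a numpy array is really required downstream
    mask_np = mask.squeeze(1).cpu().numpy()

That way the only synchronizing device-to-host copy happens once, after all GPU work has finished.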