Interpolation of network output extremely slow

Hey there,

I trained a human portrait segmentation model and wrapped it in another model that does alpha blending to replace the background (similar to MS Teams or Zoom).

The segmentation itself runs on a 512x256 image. After segmentation and alpha blending, I use torch.nn.functional.interpolate to upscale the result back to at least 720p.

Now to my problem: the segmentation model including alpha blending is very efficient and takes about 8 ms, but adding the interpolation brings the total inference time up to 35 ms.

Here’s the code of my wrapper model:

import torch

class OverlayWrapperModel(torch.nn.Module):
    def __init__(self, seg_model):
        super(OverlayWrapperModel, self).__init__()
        self.net = seg_model
        
    def forward(self, input_stacked):
        # split the stacked input back into webcam image, overlay, and alpha channel
        input_img, overlay, alpha = torch.split(input_stacked, [3, 3, 1], dim=1)
        
        # segmentation inference
        output_logits = self.net(input_img).sigmoid()
        
        # zero the alpha where the person mask is active, so the webcam image shows through there
        alpha[output_logits > 0.5] = 0
        
        # blend overlay and webcam image, then upscale to 720p
        blended = overlay * alpha + input_img * (1 - alpha)
        return torch.nn.functional.interpolate(blended, (720, 1280), mode="nearest")
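
In case it helps, here is roughly how the wrapper gets its input (the shapes and the stand-in segmentation net below are just for illustration; my real seg_model and frame sizes differ):

seg_model = torch.nn.Conv2d(3, 1, kernel_size=1)  # stand-in for the trained segmentation net

# webcam frame and background overlay, both 3-channel float tensors (N, C, H, W)
input_img = torch.rand(1, 3, 256, 512)
overlay = torch.rand(1, 3, 256, 512)
alpha = torch.ones(1, 1, 256, 512)  # overlay opacity, starts fully opaque

# stack along the channel dimension: 3 + 3 + 1 = 7 channels
input_stacked = torch.cat([input_img, overlay, alpha], dim=1)

model = OverlayWrapperModel(seg_model)
output = model(input_stacked)  # 1 x 3 x 720 x 1280 after the interpolate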

What confuses me:

  • Why is the interpolation making my model so slow?
  • Is there a better way to replace the background and get a 1280x720 output image?

Thanks in advance!

Regards
Kev

A few questions:
Is the interpolation being done on the GPU or CPU?
If the model is running on the GPU, have you explored JIT-scripting it (e.g., torch.jit.script — PyTorch 1.13 documentation)? That could potentially fuse the pointwise operations done before the interpolation.
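
Something along these lines (an untested sketch; the wrapper might need small tweaks, e.g. indexing into the result of torch.split instead of tuple-unpacking it, to be scriptable):

import torch

# `model` is the OverlayWrapperModel instance from above
scripted = torch.jit.script(model.eval())

# optional: freezing can fold constants and enable more fusion
frozen = torch.jit.freeze(scripted)

with torch.inference_mode():
    out = frozen(input_stacked)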

If you can tolerate a bleeding-edge user experience, I would also check whether PyTorch 2.0’s torch.compile can fuse some of the pointwise ops and offer a speedup as well.
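
Roughly like this (again an untested sketch, and the default backend may behave differently on your setup):

import torch

compiled = torch.compile(model)  # PyTorch 2.0+

with torch.inference_mode():
    # the first call is slow because it triggers compilation;
    # later calls should show the real inference time
    out = compiled(input_stacked)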

Sorry for the missing info:

Everything is being done on CPU.
I’ll try PyTorch 2.0’s compile function and report back with the improvements.

Thanks for your help!

If everything is being done on the CPU, I would be skeptical of compile/graph-mode improving performance. It may be that interpolation is just an expensive operation on CPU due to bandwidth requirements.
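
One quick way to check is to time the interpolation in isolation on a tensor of the same shape as your blended output, e.g.:

import time
import torch

x = torch.rand(1, 3, 256, 512)  # roughly the blended image before upscaling

# warm-up
for _ in range(10):
    torch.nn.functional.interpolate(x, (720, 1280), mode="nearest")

# average over 100 runs
start = time.perf_counter()
for _ in range(100):
    torch.nn.functional.interpolate(x, (720, 1280), mode="nearest")
print(f"interpolate alone: {(time.perf_counter() - start) / 100 * 1000:.2f} ms")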