Hey there,
I trained a human portrait segmentation model, which I then wrapped in another model that performs alpha blending to replace the background (similar to MS Teams or Zoom).
The segmentation itself runs on a 512x256 image. After segmentation and alpha blending, I use torch.nn.functional.interpolate to upscale the result back to at least 720p.
Now to my problem: the segmentation model including the alpha blending is extremely efficient: it needs about 8 ms. But the interpolation brings the model to a total inference time of 35 ms.
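In case the measurement method matters, here is roughly how I time it (a minimal sketch; the warmup/iteration counts are my choices, and I use torch.cuda.Event with explicit synchronization since CUDA kernels launch asynchronously and a naive timer can attribute earlier kernels' cost to whatever op happens to sync):

```python
import torch

def time_model(model, x, warmup=10, iters=100):
    # Warm up so one-time CUDA setup and caching don't skew the numbers
    for _ in range(warmup):
        model(x)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        model(x)
    end.record()
    torch.cuda.synchronize()  # wait for all queued kernels before reading the timer
    return start.elapsed_time(end) / iters  # milliseconds per forward pass
```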
Here’s the code of my wrapper model:
```python
import torch


class OverlayWrapperModel(torch.nn.Module):
    def __init__(self, seg_model):
        super(OverlayWrapperModel, self).__init__()
        self.net = seg_model

    def forward(self, input_stacked):
        # I stacked the webcam img and the overlay it should get
        input_img, overlay, alpha = torch.split(input_stacked, [3, 3, 1], dim=1)
        # segmentation inference
        output_logits = self.net(input_img).sigmoid()
        # cut out the mask from the alpha channel (person pixels become transparent)
        alpha[output_logits > 0.5] = 0
        # combine webcam img and overlay, then upscale to 720p
        return torch.nn.functional.interpolate(
            overlay * alpha + input_img * (1 - alpha), (720, 1280), mode="nearest"
        )
```
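For context, this is roughly how I drive it per frame (a self-contained sketch; the Conv2d stand-in for my actual segmentation net and the batch shapes are just placeholders, and I assume CUDA):

```python
seg_model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)  # placeholder for my real net
model = OverlayWrapperModel(seg_model).eval().cuda()

frame = torch.rand(1, 3, 256, 512, device="cuda")    # webcam frame, resized to 512x256
overlay = torch.rand(1, 3, 256, 512, device="cuda")  # replacement background
alpha = torch.ones(1, 1, 256, 512, device="cuda")    # starts fully opaque

stacked = torch.cat([frame, overlay, alpha], dim=1)  # (1, 7, 256, 512)
with torch.no_grad():
    out = model(stacked)  # (1, 3, 720, 1280)
```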
What confuses me:
- Why does the interpolation make my model so much slower?
- Is there a better way to replace the background and get a 1280x720 output image? (I sketched one idea I've been considering after this list.)
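For the second question, one variant I've been wondering about (just a sketch, the helper name and shapes are mine; I haven't benchmarked it): keep the webcam frame and overlay at full resolution, run segmentation on a downscaled copy, and only upsample the single-channel mask. Then only 1 channel goes through interpolate instead of 3, the blend runs at 720p instead, and the webcam pixels aren't degraded by the nearest upscale:

```python
import torch
import torch.nn.functional as F

def composite_fullres(seg_model, frame_hd, overlay_hd):
    # frame_hd / overlay_hd: (1, 3, 720, 1280) full-resolution tensors
    # run segmentation on a downscaled copy only
    small = F.interpolate(frame_hd, (256, 512), mode="bilinear", align_corners=False)
    person = (seg_model(small).sigmoid() > 0.5).float()  # (1, 1, 256, 512), 1 = person
    # upscale just the single-channel mask, then blend at full resolution
    person_hd = F.interpolate(person, (720, 1280), mode="nearest")
    return frame_hd * person_hd + overlay_hd * (1 - person_hd)
```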
Thanks in advance!
Regards
Kev