Embedding from segmentation model

I’m trying to get a mean embedding from a fine-tuned torchvision.models.segmentation.fcn_resnet50 model.
One approach I’ve tested is to replace the classifier with an nn.Identity layer (so the model outputs the feature tensor of shape [batch_size, feature_num, h, w]) and apply torch.mean at inference time:

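import torch
from torch import nn
model_embedder.classifier = nn.Identity()  # the head now passes the backbone features through
outputs = torch.mean(model_embedder(inputs)["out"], dim=[2, 3])  # the model returns a dict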

But I’ve noticed that the forward() method of torchvision segmentation models looks like this:

x = features["out"]
x = self.classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["out"] = x

With the classifier removed, F.interpolate() upsamples every feature channel to the full input resolution, which can lead to unexpected results (memory leakage and a segmentation fault, specifically).
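A back-of-the-envelope sketch of why that upsampling step is so expensive. It assumes float32 activations, the 2048-channel output of the ResNet-50 backbone, the output stride of 8 used by fcn_resnet50's dilated backbone, and an illustrative 520x520 input; the helper name `activation_bytes` is hypothetical.

```python
def activation_bytes(batch, channels, height, width, bytes_per_elem=4):
    """Size in bytes of a dense [batch, channels, height, width] tensor (float32 by default)."""
    return batch * channels * height * width * bytes_per_elem


# Assumed, illustrative setup: one 520x520 image, output stride 8, 2048 feature channels
H = W = 520
STRIDE = 8
CHANNELS = 2048

backbone_out = activation_bytes(1, CHANNELS, H // STRIDE, W // STRIDE)  # features["out"]
upsampled = activation_bytes(1, CHANNELS, H, W)  # after F.interpolate with an Identity head

print(f"backbone features: {backbone_out / 2**20:.1f} MiB")
print(f"upsampled features: {upsampled / 2**30:.2f} GiB")
print(f"ratio: {upsampled // backbone_out}x")
```

For comparison, the regular FCN head first reduces the features to num_classes channels, so only a few channels are upsampled; with an Identity head all 2048 are, which under these assumptions is about 2 GiB per image.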

I’ve also tried adding an average-pooling layer to the model directly:

import torch
from torch import nn

class EmbeddingAvgPooling(nn.Module):
    """Global average pooling: [batch_size, feature_num, h, w] -> [batch_size, feature_num]."""
    def forward(self, x):
        return torch.mean(x, dim=[2, 3])

model_embedder.classifier = EmbeddingAvgPooling()

But of course this does nothing about the F.interpolate() call itself.
(In this setup the error is not a segmentation fault but a size mismatch from F.interpolate(), which is to be expected: after pooling, the tensor has no spatial dimensions left to interpolate.)
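A minimal sketch reproducing that size mismatch with plain tensors instead of the model. A random tensor stands in for the backbone features (the sizes are arbitrary), and since the exact exception type raised by F.interpolate for a 2-D input may differ between PyTorch versions, the sketch catches the likely candidates.

```python
import torch
import torch.nn.functional as F

# Stand-in for features["out"]: [batch_size, feature_num, h, w] (arbitrary sizes)
features = torch.randn(2, 2048, 8, 8)
input_shape = (64, 64)  # spatial size of a hypothetical input image

# What EmbeddingAvgPooling does: collapse the spatial dimensions
pooled = torch.mean(features, dim=[2, 3])
print(pooled.shape)  # torch.Size([2, 2048])

# What the model's forward() then does with the classifier output
try:
    F.interpolate(pooled, size=input_shape, mode="bilinear", align_corners=False)
    interpolate_failed = False
except (ValueError, NotImplementedError, RuntimeError) as err:
    # A 2-D tensor has no spatial dims, so a 2-D target size cannot match
    interpolate_failed = True
    print(f"{type(err).__name__}: {err}")
```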

How can I fix this? Do I need to override the model’s forward() method (and if so, how), or are there other options?
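One way to avoid overriding forward() on the segmentation model is to wrap it in a small module that calls only the backbone and pools its output. A minimal sketch: `MeanEmbedder` is a hypothetical name, and it relies on the backbone returning a dict with an "out" entry, as the forward() excerpt above implies. Tiny dummy modules stand in for the fine-tuned model so the sketch runs on its own; they are not the torchvision classes.

```python
import torch
from torch import nn


class MeanEmbedder(nn.Module):
    """Hypothetical wrapper: backbone features -> global average pooling (no head, no upsampling)."""

    def __init__(self, segmentation_model: nn.Module):
        super().__init__()
        # Keep only the backbone; the classifier and F.interpolate are never called
        self.backbone = segmentation_model.backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)["out"]  # [batch_size, feature_num, h, w]
        return features.mean(dim=(2, 3))  # [batch_size, feature_num]


# Dummy stand-ins so the sketch is self-contained
class DummyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=8, padding=1)

    def forward(self, x):
        return {"out": self.conv(x)}


class DummySegModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = DummyBackbone()
        self.classifier = nn.Conv2d(16, 5, kernel_size=1)  # unused by MeanEmbedder


embedder = MeanEmbedder(DummySegModel()).eval()
with torch.no_grad():
    emb = embedder(torch.randn(2, 3, 64, 64))
print(emb.shape)  # torch.Size([2, 16])
```

With the real model this would be `MeanEmbedder(model_embedder)`; keeping the wrapper separate leaves the original segmentation model usable for segmentation.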

P.S. I understand that using a segmentation model as an embedder might not be the greatest idea, but I’m currently trying to test this specific approach.

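For me, the solution was simply to call model.backbone(x) instead of model(x) at inference time. The backbone returns a dict of feature maps, so the "out" entry has to be selected before pooling:

outputs = torch.mean(model_embedder.backbone(inputs)["out"], dim=[2, 3])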