I’m trying to get a mean embedding from a finetuned torchvision.models.segmentation.fcn_resnet50 model.
One approach I’ve tested is to swap the classifier with an nn.Identity layer (so the classifier returns a tensor of shape [batch_size, feature_num, h, w]) and apply torch.mean at inference time:
model_embedder.classifier = nn.Identity()
outputs = torch.mean(model_embedder(inputs)["out"], dim=[2, 3])  # the model returns a dict with an "out" key
But I’ve noticed that the forward() method of torchvision segmentation models looks roughly like this:
...
x = features["out"]
x = self.classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["out"] = x
...
And applying F.interpolate() to the raw backbone features leads to unexpected results (specifically memory leakage and, eventually, a segmentation fault in my runs), since with the classifier replaced, the full feature map gets upsampled back to the input resolution.
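To make the scale of that concrete, here is a rough illustration of what happens when the classifier is an identity (shapes assumed for a 512x512 input; as far as I can tell, fcn_resnet50’s backbone outputs 2048 channels at 1/8 resolution):

import torch
import torch.nn.functional as F

# With classifier = nn.Identity(), forward() hands the raw backbone features
# to F.interpolate(), which upsamples all 2048 channels back to the input size.
features = torch.randn(1, 2048, 64, 64)  # assumed backbone output for a 512x512 input
upsampled = F.interpolate(features, size=(512, 512), mode='bilinear', align_corners=False)
print(upsampled.shape)  # torch.Size([1, 2048, 512, 512]) -> roughly 2 GB in float32 per sample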
I’ve also tried adding an AvgPooling layer to the model directly:
import torch
from torch import nn

class EmbeddingAvgPooling(nn.Module):
    def __init__(self):
        super(EmbeddingAvgPooling, self).__init__()

    def forward(self, x):
        return torch.mean(x, dim=[2, 3])  # global average pooling over h, w

model_embedder.classifier = EmbeddingAvgPooling()
But of course this doesn’t address the F.interpolate() problem in any way. (In this setup the error is not a segmentation fault but a size mismatch raised by F.interpolate(), which is expected, since the pooled output has no spatial dimensions left to upsample.)
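For reference, the mismatch can be reproduced in isolation (dummy shapes assumed):

import torch
import torch.nn.functional as F

# The pooled output has no spatial dimensions left, so the subsequent
# F.interpolate(x, size=input_shape, ...) call in forward() fails.
pooled = torch.mean(torch.randn(2, 2048, 64, 64), dim=[2, 3])  # shape [2, 2048]
F.interpolate(pooled, size=(512, 512), mode='bilinear', align_corners=False)
# -> raises an error about mismatched spatial dimensions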
What can I do to fix this? Do I need to override the forward() method of the model (and if so, how?), or are there other options?
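For context, what I imagine "overriding forward()" would look like is something along these lines, i.e. calling the backbone directly and skipping classifier/F.interpolate entirely (model_embedder and inputs as above; as far as I can tell the backbone returns a dict with an "out" key), but I’m not sure whether this is the recommended way:

import torch

model_embedder.eval()
with torch.no_grad():
    # bypass the FCN wrapper's forward(): no classifier, no F.interpolate()
    features = model_embedder.backbone(inputs)["out"]  # [batch_size, 2048, h/8, w/8]
    embeddings = torch.mean(features, dim=[2, 3])      # [batch_size, 2048]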
P.S. I understand that using a segmentation model as an embedder might not be the greatest idea, but I’m currently trying to test exactly this approach.