Hi,
I have trained a model to learn the correlation between (audio, image) input pairs. While the original use case is audio-visual correlation, I also want to use the embeddings from an intermediate layer of this trained model for unimodal similarity (e.g., image-image or audio-audio similarity). Is it possible to do this without running the model's forward pass, which needs both inputs? If so, I would appreciate some guidance.
My model code (MMDNet) appears below. After training the model, I load it with the saved weights and then try to use it for image similarity as follows. However, this raises the error "DataParallel object has no attribute getImgEmb", even though the model defines that method. I could run the forward pass and get the embeddings, but the forward pass would require me to input audio as well. I want to know whether I can extract the image embeddings from the visual subnetwork (which is what getImgEmb() does) without having to pass audio input.
Thank you.
#================ code for image similarity from pretrained MMDNet ================
import torch

# load pretrained model
model = MMDNet(2)
device = torch.device("cuda" if use_cuda else "cpu")
num_gpus = torch.cuda.device_count()
gpu_list = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=gpu_list)
model.to(device)
model.load_state_dict(saved_wts)  # saved_wts: the weights saved after training
model.eval()

# get image embeddings; however, this generates the error
# "DataParallel object has no attribute getImgEmb"
emb1 = model.getImgEmb(src_img)
emb2 = model.getImgEmb(dest_img)
# get euclidean distance between (emb1, emb2) to compute similarity
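The distance step named in the last comment is not shown in the post. A minimal sketch of it, assuming `emb1` and `emb2` are 2-D tensors of shape (batch, dim); the helper name `embedding_distance` and the constant tensors are hypothetical placeholders standing in for real `getImgEmb` outputs:

```python
import torch


def embedding_distance(emb1: torch.Tensor, emb2: torch.Tensor) -> torch.Tensor:
    """Row-wise Euclidean (L2) distance between two (batch, dim) tensors."""
    return torch.linalg.vector_norm(emb1 - emb2, ord=2, dim=1)


# Placeholder embeddings (assumption: 128-dimensional, batch of 4)
emb1 = torch.zeros(4, 128)
emb2 = torch.ones(4, 128)

# One distance per pair; a smaller value means the two images are more similar
dist = embedding_distance(emb1, emb2)
```

Whether raw Euclidean distance is the right similarity measure depends on how the embeddings were trained; cosine similarity on normalized embeddings is a common alternative.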
#==== multimodal correlation model ====#
import torch.nn as nn
import torch.nn.functional as F

class MMDNet(nn.Module):
    def __init__(self, n_cls):
        super(MMDNet, self).__init__()
        self.num_class = n_cls
        self.vfeatures = VisNet()  # visual subnetwork
        self.afeatures = AudNet()  # audio subnetwork
        self.av_classifier = nn.Linear(1, n_cls)

    def forward(self, x_v, x_a):
        v_out = self.vfeatures(x_v)
        a_out = self.afeatures(x_a)
        # per-sample mean squared difference between embeddings
        # (squared Euclidean distance divided by the embedding size)
        av_dist = F.mse_loss(v_out, a_out, reduction='none').mean(1)
        out = av_dist.view(av_dist.shape[0], 1)
        out = self.av_classifier(out)
        return out

    def getImgEmb(self, x_v):
        """Get visual embeddings for a given image from the visual subnetwork."""
        v_emb = self.vfeatures(x_v)
        return v_emb
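On the error itself: `torch.nn.DataParallel` is a wrapper that stores the original network in its `module` attribute, so custom methods such as `getImgEmb` are not exposed on the wrapper directly. A minimal sketch of calling the method through `.module`, assuming small `nn.Linear` layers as stand-ins for `VisNet` and `AudNet` (their definitions are not shown here); the class name `TinyMMDNet` and the input sizes are hypothetical:

```python
import torch
import torch.nn as nn


class TinyMMDNet(nn.Module):
    """Hypothetical stand-in for MMDNet; Linear layers replace VisNet/AudNet."""

    def __init__(self, n_cls):
        super().__init__()
        self.vfeatures = nn.Linear(8, 4)  # placeholder visual subnetwork
        self.afeatures = nn.Linear(6, 4)  # placeholder audio subnetwork
        self.av_classifier = nn.Linear(1, n_cls)

    def getImgEmb(self, x_v):
        return self.vfeatures(x_v)


model = torch.nn.DataParallel(TinyMMDNet(2))
model.eval()

# Put the input on whatever device the wrapped module's parameters live on
device = next(model.module.parameters()).device
src_img = torch.randn(3, 8, device=device)  # placeholder "image" batch

# The wrapper has no getImgEmb attribute; the wrapped module does
with torch.no_grad():
    emb = model.module.getImgEmb(src_img)
```

This still runs a forward pass through the visual subnetwork, but it never touches the audio branch, so no audio input is needed. Calling through `.module` bypasses the multi-GPU scatter that `DataParallel` applies in its own `forward`, so the call runs on a single device.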