Would it be possible to get rid of these two for loops to make my computation faster?
import torch
target = torch.ones(5,1)
target[1]=0
target[4]=0
batch_size = 5
embed_feat_1 = torch.rand(batch_size,512)
embed_feat_2 =torch.rand(batch_size,512)
inputs = torch.rand(batch_size,3,1)
MSELoss = torch.nn.MSELoss()
loss = 0
for q in (range(inputs.shape[0])):
for j in range(inputs.shape[0]):
if q == j or target[q]!=target[j]:
continue
dij_plus = MSELoss(embed_feat_1[q,:], embed_feat_1[j,:]) ** 2
dij_minus = MSELoss(embed_feat_1[q,:], embed_feat_2[q,:]) ** 2
dij_minus = dij_minus + MSELoss(embed_feat_1[j,:], embed_feat_2[j,:]) **2
dij_minus = dij_minus/2
loss = loss + torch.log(torch.exp(-dij_plus)/(torch.exp(-dij_plus) + torch.exp(-dij_minus))) * (-1)
loss
the formula for the loss function is technically:
-torch.log(torch.exp(-dij_plus )/(torch.exp(-dij_plus )+torch.exp(-dij_minus )))