def FSP(layer1, layer2):
    """Compute a channel-by-channel spatial inner-product matrix between two feature maps.

    ``layer1`` is resized (bicubic, ``align_corners=True``) to the spatial size
    of ``layer2``. Then, for every pair of channels (i from ``layer1``, j from
    ``layer2``), the element-wise product is averaged over all spatial positions.
    The name suggests this is the FSP matrix used in feature distillation.
    NOTE(review): confirm against the caller.

    The computation is fully vectorized: a single batched ``einsum`` replaces
    the original batch/channel/channel Python loops.

    Args:
        layer1: 4-D tensor of shape ``(b, m, h1, w1)``. It must be 4-D because
            bicubic ``interpolate`` requires a 4-D input.
        layer2: 4-D tensor of shape ``(b, n, h2, w2)`` with the same batch size
            as ``layer1``.

    Returns:
        Tensor of shape ``(b, m, n)`` where
        ``out[k, i, j] = mean(mid[k, i] * layer2[k, j])`` and ``mid`` is the
        resized ``layer1``. The result lives on the inputs' device and has
        their dtype, and it stays differentiable with respect to both inputs.
    """
    h2, w2 = layer2.shape[2], layer2.shape[3]
    mid = torch.nn.functional.interpolate(
        layer1, [h2, w2], mode='bicubic', align_corners=True
    )
    # Sum over the spatial axes (h, w) for every (i, j) channel pair in each
    # batch element, then divide by the number of positions to reproduce
    # torch.mean over the spatial map.
    return torch.einsum('bihw,bjhw->bij', mid, layer2) / (h2 * w2)
This is what I am trying to implement, but without using explicit Python loops.