I have a loss function that computes a weighted mean squared error over every possible permutation of the output along one axis, keeping the best permutation. I do this because I cannot be certain of the order in which my network will output its values, and I do not want to bias the output towards one particular ordering.
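Written out, with $w_{b,v}$ the weight of vertex $v$ in batch element $b$ and $S_F$ the set of permutations of the $F$ feature indices, the quantity the code computes is as follows. Note that the squared errors are summed over vertices and features; only the final reduction over the batch is a mean.

$$
\mathcal{L} = \frac{1}{B}\sum_{b=1}^{B}\,\min_{\pi\in S_F}\,\sum_{v=1}^{V}\sum_{f=1}^{F} w_{b,v}\left(\hat{y}_{b,v,\pi(f)} - y_{b,v,f}\right)^2
$$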
I have implemented and commented the following code:
import torch
from itertools import permutations
import torch.nn.functional as F
batch_size = 20
n_vertices = 100
n_features = 5
## Generate a set of mock output and truth data of size B, V, F
y_hat = torch.randn(batch_size, n_vertices, n_features, requires_grad=True)
y = torch.randn(batch_size, n_vertices, n_features, requires_grad=True)
## Normalise mock output and truth data along the F axis with softmax, computed as exp(log_softmax)
y_hat = torch.exp(F.log_softmax(y_hat, dim=-1))
y = torch.exp(F.log_softmax(y, dim=-1))
## Generate a set of mock weight data of size B, V
weight = F.relu(torch.randn(batch_size, n_vertices, requires_grad=True))
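## Generate the indices of all possible permutations of the F axis, output-> Idx,F
## (uses n_features; the original referenced an undefined `output` and an unimported `np`)
idxs = torch.tensor(list(permutations(range(n_features))))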
## Reshape weight factor, output-> B,V,1,1
weight = weight.view(batch_size, n_vertices, 1, 1)
## Index output tensor -> output-> B,V,Idx,F
y_hat = y_hat[:, :, idxs]
## Repeat truth tensor -> output-> B,V,Idx,F
y = y.view(batch_size, n_vertices, -1, n_features).repeat(1, 1, len(idxs), 1)
## Calculate loss
loss = weight*(y_hat - y)**2
## Permute the axes, output-> B,Idxs, V, F
## Flatten the last two axes, output -> B, Idxs, VxF
loss = loss.permute(0,2,1,3).flatten(start_dim=2)
## Sum over the VxF axis (total loss of each permutation), output -> B, Idxs
loss = torch.sum(loss, dim=-1)
## Take the minimum over the Idx axis (best permutation for each batch element), output -> B
loss = torch.min(loss, dim=-1).values
## Average over the batch dimension
loss = torch.mean(loss)
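One way to check the vectorised version is to compare it against a naive loop over the permutations on small inputs, and to call `backward()` to confirm that gradients reach the leaf tensor created before the indexing. This is a sketch assuming PyTorch; the helper names `vectorised_loss` and `loop_loss` are hypothetical, introduced only for this check.

```python
import torch
import torch.nn.functional as F
from itertools import permutations


def vectorised_loss(y_hat, y, weight):
    """Permutation-invariant weighted squared error, as in the code above.

    y_hat, y: (B, V, F); weight: (B, V).
    """
    n_features = y_hat.size(-1)
    # (K, F) index tensor, K = F! permutations of the feature axis
    idxs = torch.tensor(list(permutations(range(n_features))))
    y_hat_perm = y_hat[:, :, idxs]                   # (B, V, K, F)
    diff2 = (y_hat_perm - y.unsqueeze(2)) ** 2       # broadcast y over K
    loss = weight[:, :, None, None] * diff2          # (B, V, K, F)
    loss = loss.sum(dim=(1, 3))                      # (B, K)
    return loss.min(dim=-1).values.mean()


def loop_loss(y_hat, y, weight):
    """Reference implementation with explicit Python loops."""
    batch_losses = []
    for b in range(y_hat.size(0)):
        best = None
        for perm in permutations(range(y_hat.size(-1))):
            p = list(perm)
            # weighted squared error summed over vertices and features
            val = (weight[b, :, None] * (y_hat[b][:, p] - y[b]) ** 2).sum()
            best = val if best is None else torch.minimum(best, val)
        batch_losses.append(best)
    return torch.stack(batch_losses).mean()


torch.manual_seed(0)
B, V, n_feat = 3, 4, 3
logits = torch.randn(B, V, n_feat, requires_grad=True)  # leaf tensor
y_hat = F.softmax(logits, dim=-1)
y = F.softmax(torch.randn(B, V, n_feat), dim=-1)
weight = F.relu(torch.randn(B, V))

lv = vectorised_loss(y_hat, y, weight)
ll = loop_loss(y_hat, y, weight)
print(torch.allclose(lv, ll))  # True if the two implementations agree

lv.backward()
print(logits.grad is not None)  # gradients flow back through the indexing
```

Advanced indexing and `torch.min` are both differentiable: the gradient of the minimum is routed only to the entries of the selected permutation, which are then scattered back to their original feature positions.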
Can someone confirm this is the correct operation? Are the gradients still calculated in spite of the indexing? Is there a better way of doing what I am trying to achieve?
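On the question of a better way: because the squared error is summed over vertices, the loss of a permutation $\pi$ decomposes into entries of a small $F \times F$ cost matrix, $C_{b,i,j} = \sum_v w_{b,v}(\hat{y}_{b,v,i} - y_{b,v,j})^2$, namely $\sum_f C_{b,\pi(f),f}$. Building $C$ once avoids materialising the (B, V, F!, F) tensor; the permutations then only index a (B, F, F) tensor. A sketch assuming PyTorch (the function name `perm_invariant_loss` is hypothetical):

```python
import torch
from itertools import permutations


def perm_invariant_loss(y_hat, y, weight):
    """Same value as the brute-force version, via an F x F cost matrix.

    y_hat, y: (B, V, F); weight: (B, V).
    """
    n_features = y_hat.size(-1)
    # cost[b, i, j] = sum_v weight[b, v] * (y_hat[b, v, i] - y[b, v, j]) ** 2
    diff2 = (y_hat.unsqueeze(-1) - y.unsqueeze(-2)) ** 2     # (B, V, F, F)
    cost = (weight[:, :, None, None] * diff2).sum(dim=1)      # (B, F, F)

    # (K, F) permutation indices; cols broadcasts to (K, F)
    idxs = torch.tensor(list(permutations(range(n_features))))
    cols = torch.arange(n_features)
    # entry [b, k, f] = cost[b, idxs[k, f], f]; summing over f gives (B, K)
    perm_cost = cost[:, idxs, cols].sum(dim=-1)
    return perm_cost.min(dim=-1).values.mean()


torch.manual_seed(0)
B, V, n_feat = 4, 10, 4
y_hat = torch.softmax(torch.randn(B, V, n_feat), dim=-1)
y = torch.softmax(torch.randn(B, V, n_feat), dim=-1)
weight = torch.relu(torch.randn(B, V))

# Reference: the brute-force (B, V, K, F) formulation
idxs = torch.tensor(list(permutations(range(n_feat))))
brute = weight[:, :, None, None] * (y_hat[:, :, idxs] - y.unsqueeze(2)) ** 2
brute = brute.sum(dim=(1, 3)).min(dim=-1).values.mean()
print(torch.allclose(perm_invariant_loss(y_hat, y, weight), brute))
```

The number of permutations still grows as F! (10! is already 3,628,800). Since choosing the permutation is a linear assignment problem over `cost`, a Hungarian-algorithm solver such as `scipy.optimize.linear_sum_assignment` could pick the minimising permutation for each batch element in polynomial time; the loss would then be recomputed in PyTorch from the selected indices so that gradients still flow.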