I have a loss function that computes a weighted mean squared error for every possible permutation of the output along a certain axis and keeps the minimum. This is because I cannot be certain of the order in which my network will output its values, and I do not want to bias the output.
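Written out, the loss I am aiming for is roughly the following. The notation is introduced here only for illustration: $B$, $V$, $F$ are the batch, vertex and feature sizes, $w_{b,v}$ is the weight, and $S_F$ is the set of all permutations of the $F$ feature indices.

```latex
\mathcal{L}
  = \frac{1}{B} \sum_{b=1}^{B}
    \min_{\pi \in S_F}
    \sum_{v=1}^{V} \sum_{f=1}^{F}
    w_{b,v} \left( \hat{y}_{b,v,\pi(f)} - y_{b,v,f} \right)^{2}
```

Note that one permutation $\pi$ is chosen per batch element and shared by all vertices in it, and that the inner term is a sum over $V$ and $F$ rather than a mean.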
I have implemented and commented the following code:
```python
import torch
import torch.nn.functional as F
from itertools import permutations

batch_size = 20
n_vertices = 100
n_features = 5

## Generate a set of mock output and truth data of size B, V, F
y_hat = torch.randn(batch_size, n_vertices, n_features, requires_grad=True)
y = torch.randn(batch_size, n_vertices, n_features, requires_grad=True)

## Activate mock output and truth data using logsoftmax
y_hat = torch.exp(F.log_softmax(y_hat, dim=-1))
y = torch.exp(F.log_softmax(y, dim=-1))

## Generate a set of mock weight data of size B, V
weight = F.relu(torch.randn(batch_size, n_vertices, requires_grad=True))

## Generate the indices of all the possible permutations of the F axis, shape Idx, F
## (was np.arange(output.size(-1)): np was never imported and output was undefined)
idxs = torch.tensor(list(permutations(range(y_hat.size(-1)))))

## Reshape weight factor, output -> B,V,1,1
weight = weight.view(batch_size, n_vertices, 1, 1)

## Index output tensor, output -> B,V,Idx,F
y_hat = y_hat[:, :, idxs]

## Repeat truth tensor, output -> B,V,Idx,F
y = y.view(batch_size, n_vertices, -1, n_features).repeat(1, 1, len(idxs), 1)

## Calculate loss
loss = weight * (y_hat - y) ** 2

## Permute the axes, output -> B,Idx,V,F
## Flatten the last two axes, output -> B,Idx,VxF
loss = loss.permute(0, 2, 1, 3).flatten(start_dim=2)

## Calculate the sum across the VxF axis (sum of each permutation)
loss = torch.sum(loss, dim=-1)

## Calculate the minimum loss across the Idx axis (min of each batch)
## torch.min with a dim returns (values, indices), so keep only the values
loss = torch.min(loss, dim=-1).values

## Calculate the mean across the batches
loss = torch.mean(loss)
```
Can someone confirm this is the correct operation? Are the gradients still calculated in spite of the indexing? Is there a better way of doing what I am trying to achieve?
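One way to probe the gradient question empirically is a small check like the one below. It is a sketch under stated assumptions, not a definitive implementation: the function name `permutation_min_loss` and the tiny tensor shapes are made up for illustration, and broadcasting is used in place of the explicit `repeat`. The truth and weight tensors are plain constants here, so only `y_hat` requires gradients.

```python
import torch
from itertools import permutations


def permutation_min_loss(y_hat, y, weight):
    """Weighted squared error, minimised over permutations of the last axis.

    Hypothetical helper for illustration.
    y_hat, y: (B, V, F); weight: (B, V). Returns a scalar tensor.
    """
    n_features = y_hat.size(-1)
    # (Idx, F) integer tensor holding every permutation of the feature indices
    idxs = torch.tensor(list(permutations(range(n_features))))
    # Advanced indexing gathers the permuted copies: (B, V, Idx, F)
    y_hat_perm = y_hat[:, :, idxs]
    # Broadcasting instead of repeat: y -> (B, V, 1, F), weight -> (B, V, 1, 1)
    sq_err = weight[:, :, None, None] * (y_hat_perm - y[:, :, None, :]) ** 2
    # Sum over vertices and features for each permutation: (B, Idx)
    per_perm = sq_err.sum(dim=(1, 3))
    # Best permutation per batch element, then mean over the batch
    return per_perm.min(dim=-1).values.mean()


torch.manual_seed(0)
y_hat = torch.randn(4, 6, 3, requires_grad=True)  # small mock shapes (assumed)
y = torch.randn(4, 6, 3)
weight = torch.rand(4, 6)

loss = permutation_min_loss(y_hat, y, weight)
loss.backward()

# If the indexing blocked gradients, y_hat.grad would be None here
grad_reaches_input = y_hat.grad is not None and bool(y_hat.grad.abs().sum() > 0)
print(loss.item(), grad_reaches_input)
```

Indexing with an integer tensor is a differentiable gather, so gradients should reach `y_hat`. The `min` only passes gradient through the permutation it selected for each batch element, so the other permutations contribute nothing to that step. The number of permutations grows as F!, which is 120 for F = 5 but becomes impractical quickly for larger F.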