Loss function that permutes a tensor along an axis

I have a loss function that computes a weighted mean squared error for every possible permutation of the output along one axis and keeps the smallest value. I do this because I cannot be certain in which order my network will output its values, and I do not want to bias the output toward any particular ordering.
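Reading the code below as the definition, the loss appears to be the following, where $S_F$ is the set of all permutations of the $F$ feature indices and $w_{b,v}$ is the per-vertex weight. One permutation $\pi$ is chosen per batch element and shared by all of its vertices, because the minimum is taken only after summing over the vertex axis.

$$
\mathcal{L} = \frac{1}{B}\sum_{b=1}^{B}\;\min_{\pi\in S_F}\;\sum_{v=1}^{V}\sum_{f=1}^{F} w_{b,v}\,\bigl(\hat{y}_{b,v,\pi(f)} - y_{b,v,f}\bigr)^{2}
$$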

I have implemented and commented the following code:

import torch
from itertools import permutations
import torch.nn.functional as F

batch_size = 20
n_vertices = 100
n_features = 5

## Generate a set of mock output and truth data of size B, V, F
y_hat = torch.randn(batch_size, n_vertices, n_features, requires_grad=True)
y = torch.randn(batch_size, n_vertices, n_features, requires_grad=True)

## Activate mock output and truth data using softmax (exp of log_softmax)
y_hat = torch.exp(F.log_softmax(y_hat, dim=-1))
y = torch.exp(F.log_softmax(y, dim=-1))

## Generate a set of mock weight data of size B, V
weight = F.relu(torch.randn(batch_size, n_vertices, requires_grad=True))

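## Generate the indices of all possible permutations of the F axis, output-> Idx,F
## (an integer tensor, so no numpy is needed and the undefined `output` is avoided)
idxs = torch.tensor(list(permutations(range(n_features))))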

## Reshape weight factor, output-> B,V,1,1
weight = weight.view(batch_size, n_vertices, 1, 1)

## Index output tensor, output-> B,V,Idx,F
y_hat = y_hat[:, :, idxs]

## Repeat truth tensor, output-> B,V,Idx,F
y = y.view(batch_size, n_vertices, -1, n_features).repeat(1, 1, len(idxs), 1)

## Calculate loss
loss = weight*(y_hat - y)**2

## Permute the axes, output-> B,Idxs, V, F
## Flatten the last two axes, output -> B, Idxs, VxF
loss = loss.permute(0,2,1,3).flatten(start_dim=2)

## Calculate the sum across the VxF axis (sum of each permutation)
loss = torch.sum(loss, dim=-1)

## Calculate the minimum loss across the Idx axis (best permutation for each batch element)
loss = torch.min(loss, dim=-1).values

## Calculate the mean across the batches
loss = torch.mean(loss)
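A sketch of the same computation wrapped in a function, assuming inputs shaped B,V,F and weights shaped B,V as in the mock data above. It builds the permutation table as an integer tensor and broadcasts `y.unsqueeze(2)` instead of calling `.view(...).repeat(...)`, so the truth tensor is not copied once per permutation. The function name `permutation_invariant_wmse` and the small sizes are illustrative choices, not from the question. Indexing with an integer tensor and `min` along a dimension are both differentiable in PyTorch (the gradient is routed to the indexed elements and to the position of the minimum), so calling `backward()` and checking that the leaf's `.grad` is populated is a quick sanity check.

```python
import itertools

import torch
import torch.nn.functional as F


def permutation_invariant_wmse(y_hat, y, weight):
    """Weighted squared error, minimised over permutations of the last axis.

    y_hat, y: (B, V, F) tensors; weight: (B, V) tensor.
    Returns the batch mean of each sample's minimum over all F! permutations.
    """
    n_features = y_hat.size(-1)
    # (P, F) integer tensor with P = F!, one row per permutation of range(F)
    perms = torch.tensor(
        list(itertools.permutations(range(n_features))), device=y_hat.device
    )
    # Advanced indexing (differentiable): (B, V, F) -> (B, V, P, F)
    y_hat_perm = y_hat[:, :, perms]
    # Broadcasting replaces .view(...).repeat(...): y becomes (B, V, 1, F)
    sq_err = (y_hat_perm - y.unsqueeze(2)) ** 2
    # Weight as (B, V, 1, 1), then sum over V and F -> (B, P)
    per_perm = (weight[:, :, None, None] * sq_err).sum(dim=(1, 3))
    # Best permutation per sample, then mean over the batch
    return per_perm.min(dim=-1).values.mean()


# Small illustrative sizes (assumed, not from the question)
torch.manual_seed(0)
B, V, n_feat = 4, 10, 3
logits = torch.randn(B, V, n_feat, requires_grad=True)
y_hat = F.softmax(logits, dim=-1)
y = F.softmax(torch.randn(B, V, n_feat), dim=-1)
weight = F.relu(torch.randn(B, V))

loss = permutation_invariant_wmse(y_hat, y, weight)
loss.backward()
# A populated logits.grad shows gradients reached the leaf through the indexing
print(f"loss={loss.item():.4f}, grad populated: {logits.grad is not None}")
```

If the prediction is an exact permutation of the truth along F (the same permutation for every vertex), this loss is exactly zero, which is the invariance the question is after.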

Can someone confirm that this is the correct operation? Are the gradients still computed through the indexing? Is there a better way of achieving what I am trying to do?
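One alternative worth considering, sketched under the assumption that the loss should equal the brute-force quantity computed above: because the same permutation is applied to every vertex of a sample, the V axis can be summed once into an F×F cost matrix C per sample, with C[b, f, g] = Σ_v w[b, v] (ŷ[b, v, g] − y[b, v, f])², and each permutation's loss is then the sum of F entries of C. The large intermediate becomes B×V×F×F instead of B×V×F!×F (25 versus 600 entries per vertex for F = 5). The function name is hypothetical; only standard PyTorch broadcasting, indexing and reductions are used.

```python
import itertools

import torch
import torch.nn.functional as F


def permutation_invariant_wmse_via_cost(y_hat, y, weight):
    """Same loss as the brute-force version, via a per-sample F x F cost matrix.

    C[b, f, g] = sum_v weight[b, v] * (y_hat[b, v, g] - y[b, v, f]) ** 2,
    and the loss of permutation pi is sum_f C[b, f, pi(f)].
    """
    n_features = y_hat.size(-1)
    # diff[b, v, f, g] = y_hat[b, v, g] - y[b, v, f], shape (B, V, F, F)
    diff = y_hat.unsqueeze(2) - y.unsqueeze(3)
    # Reduce the V axis once: cost has shape (B, F, F)
    cost = (weight[:, :, None, None] * diff ** 2).sum(dim=1)
    # (P, F) permutation table, P = F!
    perms = torch.tensor(
        list(itertools.permutations(range(n_features))), device=y_hat.device
    )
    rows = torch.arange(n_features, device=y_hat.device)
    # cost[:, rows, perms] has shape (B, P, F) with entries C[b, f, perms[p, f]]
    per_perm = cost[:, rows, perms].sum(dim=-1)
    # Best permutation per sample, then mean over the batch
    return per_perm.min(dim=-1).values.mean()


# Usage with the shapes from the question (B=20, V=100, F=5)
torch.manual_seed(0)
logits = torch.randn(20, 100, 5, requires_grad=True)
y_hat = F.softmax(logits, dim=-1)
y = F.softmax(torch.randn(20, 100, 5), dim=-1)
weight = F.relu(torch.randn(20, 100))

loss = permutation_invariant_wmse_via_cost(y_hat, y, weight)
loss.backward()
print(loss.item())
```

For larger F, where enumerating F! permutations becomes impractical, minimising Σ_f C[f, π(f)] over π is a linear assignment problem (the problem the Hungarian algorithm solves). `scipy.optimize.linear_sum_assignment` could solve it on a detached NumPy copy of each sample's cost matrix, and the matched entries of `cost` could then be gathered with the returned indices so gradients still flow. That is a swapped-in technique, not part of the original approach.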