Hi PyTorch,

I’m trying to implement a weighted distance function for my loss function. Basically, I’m creating an n x n pairwise distance matrix dd between my two inputs X (n x 3 x 3) and Y (n x 3 x 3). My distance is basically the norm over the final dimension, summed: dd = torch.sum(torch.norm(x - y, 2, -1)). The thing is, I want this distance to be weighted, so my idea was to do something like dd = 2*torch.sum(torch.norm(x - y, 2, -1)) + torch.max(torch.norm(x - y, 2, -1)) - torch.min(torch.norm(x - y, 2, -1)). (Note that torch.max without a dim argument returns a plain scalar tensor, so no [0] indexing is needed.) In effect this computes a distance of 3*largest + 2*middle + 1*smallest.
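For concreteness, here is a small sanity check of that weighting identity on a single hypothetical row of three norms (the values are made up for illustration):

```python
import torch

# Toy check: 2*sum + max - min == 3*largest + 2*middle + 1*smallest.
norms = torch.tensor([5.0, 3.0, 1.0])  # hypothetical row norms, sorted descending

weighted = 2 * norms.sum() + norms.max() - norms.min()
expected = 3 * norms[0] + 2 * norms[1] + 1 * norms[2]
print(torch.isclose(weighted, expected))  # expect True
```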

My question is about autograd: when calculating the gradients, won’t these weighting terms just get absorbed by the chain rule, the same as when you double a loss? How can I ensure this isn’t the case, and how can I verify it? Would instantiating a new variable for the max and min values be enough?
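One way to check this empirically is to compare x.grad from the plain sum against the weighted version; the shapes, names, and global max/min here are just a sketch mirroring the formula above:

```python
import torch

# Hedged sketch: check whether the max/min terms contribute their own
# gradients, rather than acting like a constant scale factor.
torch.manual_seed(0)
x = torch.randn(4, 3, 3, requires_grad=True)
y = torch.randn(4, 3, 3)

# Plain (unweighted) distance: sum of the norms over the last dimension.
row_norms = torch.norm(x - y, p=2, dim=-1)  # shape (4, 3)
plain = row_norms.sum()
plain.backward()
grad_plain = x.grad.clone()
x.grad = None

# Weighted version, mirroring the formula in the post (global max/min).
row_norms = torch.norm(x - y, p=2, dim=-1)  # rebuild the graph
weighted = 2 * row_norms.sum() + row_norms.max() - row_norms.min()
weighted.backward()
grad_weighted = x.grad.clone()

# If the extra terms were "chain-ruled out" like doubling a loss,
# grad_weighted would equal 2 * grad_plain everywhere. The rows holding
# the max and min norms receive different gradients, so it does not.
print(torch.allclose(grad_weighted, 2 * grad_plain))  # expect False
```

The intuition: max and min are piecewise-linear selections, so autograd routes their gradient only to the selected (argmax/argmin) elements, which is exactly what gives the largest norm its extra weight.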

The reason I’m trying to do this is that the network seems to be settling into a local minimum where only 2 of the 3 vectors in the final dimension are being reconstructed correctly, so I wanted to give the largest loss more weight. Any help with this question would be greatly appreciated.