I’ve defined a custom loss function for a regression problem, but it contains a loop that is very slow, and I was wondering if there is a better way to do this.
I generate a prediction and then want to find the vector in a table (i.e. an embedding layer) which is closest to the residual. I use a loop for this, but it is very slow. I wanted to use torch.Tensor.apply_, but this is for cpu only, not gpu. I discovered torch.func.vmap, but I’m not sure how to use it over both the table and a minibatch (which seems to be its primary purpose). I should also mention that I want to update the table during training, so perhaps I need to use requires_grad = True in torch.arange?
Thank you in advance for any advice!
def lossFunction(self, prediction, label):
    residual = (label - prediction)
    for i in torch.arange(self.nCentroids, device='cuda'):
        thisNorm = torch.linalg.vector_norm(residual - self.embedding(i))
        if i == 0:
            bestNorm = thisNorm
            bestIndex = 0
        elif thisNorm < bestNorm:
            bestNorm = thisNorm
            bestIndex = i
    thisNorm = torch.linalg.vector_norm(prediction - self.embedding(i))
    return thisNorm
If you can package your embedding-layer table as a pytorch tensor, you can find the
nearest vector with pytorch tensor operations, avoiding an explicit python loop.
Here, self.embedding looks like a function that takes an index (packaged as a
zero-dimensional pytorch tensor). Let me assume that embedding is or can be packaged
as a tensor. In particular, let prediction and residual be vectors of length n and let embedding be a tensor of shape [n, nCentroids], with each column holding one centroid vector. You can then replace the explicit loop with a single call to argmin().
The line that recomputes thisNorm after the loop uses prediction and the final value of i, discarding bestNorm and bestIndex. This appears to be a typo – I assume that the intent was something like:
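    # assumed intent: return the distance to the closest centroid found in the loop
    thisNorm = torch.linalg.vector_norm(residual - self.embedding(bestIndex))
    return thisNorm
(That is the same value as bestNorm, so simply returning bestNorm would also work.)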
Note, we will use broadcasting to compare the residual vector with each of the columns of the embedding tensor. For the shapes to broadcast, view residual as a column of shape [n, 1] (for example, residual.unsqueeze(1)); then (residual.unsqueeze(1) - embedding).shape is [n, nCentroids].
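Putting this together, here is a minimal sketch of the vectorized lookup, assuming the centroid table is available as a plain tensor self.table of shape [n, nCentroids] (the name self.table and how you obtain it are assumptions; if self.embedding is an nn.Embedding, its .weight has shape [nCentroids, n], so its transpose has this layout):
def lossFunction(self, prediction, label):
    # assumes: import torch; self.table is a trainable tensor of shape [n, nCentroids]
    residual = label - prediction                      # shape [n]
    diffs = residual.unsqueeze(1) - self.table         # broadcasts to [n, nCentroids]
    norms = torch.linalg.vector_norm(diffs, dim=0)     # one norm per centroid, shape [nCentroids]
    bestIndex = norms.argmin()                         # index of the closest centroid
    return norms[bestIndex]                            # same value as norms.min()
Returning norms[bestIndex] (or equivalently norms.min()) keeps the loss differentiable with respect to the selected column of the table, so gradients can still update it during training.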
The requires_grad = True wouldn’t go on the torch.arange() indices – what needs to be trainable is the table itself. Keeping it as an nn.Embedding (or registering it as an nn.Parameter) would be the typical way to do it. (There would be a connotation that self.embedding is naturally part of your model, but this is not a requirement.) But you can also train a plain tensor that carries requires_grad = True without it being a Parameter.
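As a small illustrative sketch of both options (the sizes and names here are made up):
import torch

n, nCentroids = 8, 16                                  # illustrative sizes

# Option 1 (typical): register the table on your module as a Parameter,
# e.g. in __init__:
#     self.table = torch.nn.Parameter(torch.randn(n, nCentroids))
# An nn.Embedding works the same way -- its .weight is already a Parameter,
# so the optimizer picks it up via model.parameters().

# Option 2: a plain tensor with requires_grad = True, passed to the optimizer explicitly.
table = torch.randn(n, nCentroids, requires_grad=True)
optimizer = torch.optim.SGD([table], lr=0.1)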
Take a look at the documentation for linalg.vector_norm() and broadcasting semantics,
in particular how the dim argument works for vector_norm(). Print out the shapes for residual and self.table and see if everything lines up correctly.
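For example, a quick check with made-up sizes should show the shapes lining up:
import torch

residual = torch.randn(5)                              # n = 5
table = torch.randn(5, 3)                              # nCentroids = 3

diffs = residual.unsqueeze(1) - table
print(diffs.shape)                                     # torch.Size([5, 3])

norms = torch.linalg.vector_norm(diffs, dim=0)
print(norms.shape)                                     # torch.Size([3])
print(norms.argmin())                                  # index of the closest column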