CUDA error THCTensorScatterGather.cu:70 when using backward() on gradient penalty

That’s what I had read in other discussions, but the strange thing is that changing a line in my loss function from:
torch.nn.functional.pairwise_distance(...)
to:
1 - torch.nn.functional.cosine_similarity(...)
avoids the device-side assert error (although the loss then goes to -inf after a few iterations). The rest of the code, where indexing problems could occur, is exactly the same.
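For anyone trying to reproduce this, here is a minimal sketch of the setup as I understand it: a distance term that can be switched between the two lines above, plus a WGAN-GP style gradient penalty that needs double backward (`create_graph=True`). The encoder, critic, shapes, penalty weight of 10 and the helper names `distance_term`, `gradient_penalty` and `run_step` are all hypothetical stand-ins, not my actual model. Running one step with `device="cpu"` is a common way to turn an opaque device-side assert into a regular Python exception that points at the failing op (setting `CUDA_LAUNCH_BLOCKING=1` helps similarly on GPU).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def distance_term(a, b, use_cosine):
    """Hypothetical distance loss; the two branches mirror the swapped lines."""
    if use_cosine:
        # 1 - cos(a, b) lies in [0, 2] for every row
        return (1 - F.cosine_similarity(a, b, dim=1)).mean()
    # Euclidean distance ||a - b + eps||_2 for every row
    return F.pairwise_distance(a, b).mean()

def gradient_penalty(critic, real, fake):
    """WGAN-GP style penalty (||d critic / d x||_2 - 1)^2 at random interpolates."""
    alpha = torch.rand(real.size(0), 1, device=real.device)
    # Detach both inputs so the interpolate is a fresh leaf tensor
    interp = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(interp)
    # create_graph=True keeps the graph so the penalty itself can be backpropagated
    (grads,) = torch.autograd.grad(scores.sum(), interp, create_graph=True)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

def run_step(use_cosine, device="cpu"):
    """One forward/backward pass with hypothetical toy modules and random data."""
    encoder = nn.Linear(8, 4).to(device)
    critic = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1)).to(device)
    x1 = torch.randn(5, 8, device=device)
    x2 = torch.randn(5, 8, device=device)
    a, b = encoder(x1), encoder(x2)
    loss = distance_term(a, b, use_cosine) + 10.0 * gradient_penalty(critic, a, b)
    loss.backward()
    return loss.item(), encoder.weight.grad
```

Calling `run_step(False)` and `run_step(True)` on CPU exercises both variants. Note that `1 - cosine_similarity` is bounded in [0, 2], so if the total loss heads to -inf with that variant, the divergence presumably comes from another term in the loss.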