Understanding cdist() function

Probably yes, and these lines are probably the right ones.

In your code, other methods such as torch.cat will create contiguous tensors, as seen here:

import torch

# expand() returns a non-contiguous view of the original tensor
a = torch.randn(1, 1).expand(10, 10)
print(a.is_contiguous())
> False

# a freshly allocated tensor is contiguous
b = torch.randn(10, 10)
print(b.is_contiguous())
> True

# torch.cat allocates a new, contiguous result
c = torch.cat((a, b), dim=1)
print(c.is_contiguous())
> True

The main issue is that your data is too large for the applied operations: at least some of them work on contiguous tensors internally, which creates the memory increase.
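As a small sketch of why that matters (using storage().size() just to inspect the number of stored elements), an expanded tensor is only a view reusing the original storage, while a contiguous copy materializes every element:

import torch

# the expanded view still points to the single stored element
a = torch.randn(1, 1).expand(10, 10)
print(a.storage().size())
> 1

# a contiguous copy (which some ops create internally) allocates all elements
print(a.contiguous().storage().size())
> 100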

For mixed-precision training, I would recommend installing the nightly build and using native amp as described here.
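As a rough sketch of what native amp (torch.cuda.amp) looks like, assuming a placeholder model, optimizer, and random data instead of your actual training code:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = "cuda"
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

for _ in range(10):
    data = torch.randn(8, 10, device=device)
    target = torch.randint(0, 2, (8,), device=device)

    optimizer.zero_grad()
    with autocast():                       # forward pass runs in mixed precision
        output = model(data)
        loss = nn.functional.cross_entropy(output, target)
    scaler.scale(loss).backward()          # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)                 # unscales grads, skips step on inf/nan
    scaler.update()                        # adjust the scale factor for the next iteration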
