import torch
m = 2500
c = 256
# works fine on CPU
x = torch.randn(1, m, c)
# raises an error: RuntimeError: CUDA error: invalid configuration argument
# x = torch.randn(1, m, c).to("cuda:0")
x.requires_grad = True
dist = torch.cdist(x, x, p=2)
dist.sum().backward()
It’s not an OOM issue. I found some similar PRs in the PyTorch GitHub repo:
https://github.com/pytorch/pytorch/pull/31593
https://github.com/pytorch/pytorch/pull/31167
I’m not sure whether these have been merged, or whether they would handle my case.
Is there a workaround for this?
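One possible workaround (a sketch, not verified against that specific CUDA kernel bug) is to skip `torch.cdist` entirely and compute the Euclidean distances by hand using the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b. This routes everything through `bmm` and elementwise ops, whose kernels don't hit the `cdist` launch-configuration path. The `pairwise_dist` helper name and the `eps` clamp are my own choices; the clamp avoids NaN gradients from `sqrt` at exactly zero (e.g. on the diagonal):

```python
import torch

def pairwise_dist(x, y, eps=1e-12):
    # x: (B, m, c), y: (B, n, c) -> (B, m, n) Euclidean distances.
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed batch-wise.
    x2 = (x ** 2).sum(dim=-1, keepdim=True)                    # (B, m, 1)
    y2 = (y ** 2).sum(dim=-1, keepdim=True).transpose(-2, -1)  # (B, 1, n)
    sq = x2 + y2 - 2.0 * torch.bmm(x, y.transpose(-2, -1))     # (B, m, n)
    # Clamp tiny negatives (floating-point error) and zeros so that
    # sqrt has a finite gradient everywhere.
    return sq.clamp_min(eps).sqrt()

x = torch.randn(1, 2500, 256, requires_grad=True)
dist = pairwise_dist(x, x)
dist.sum().backward()  # backward works; on CUDA it avoids the cdist kernel
```

The trade-off is slightly worse numerical precision than `cdist` for nearly coincident points, since the subtraction can cancel; if that matters, chunking the `cdist` call into smaller row blocks is another option.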