`torch.cdist` cannot handle large matrices on GPU

import torch


m = 2500
c = 256

# Works fine on CPU.
x = torch.randn(1, m, c)
# On GPU this raises: RuntimeError: CUDA error: invalid configuration argument
# x = torch.randn(1, m, c).to("cuda:0")

x.requires_grad = True

dist = torch.cdist(x, x, p=2)

dist.sum().backward()

This is not an out-of-memory error. I found two related pull requests on the PyTorch GitHub:
https://github.com/pytorch/pytorch/pull/31593
https://github.com/pytorch/pytorch/pull/31167
I am not sure whether they will handle my case once they are merged.

Is there a workaround for this?

I see @ptrblck in the first link. :wink:

My PR should fix your issue. I’ll continue working on it this week. :wink:
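
Until then, one possible workaround is to compute the pairwise Euclidean distances by hand instead of calling `torch.cdist`. The sketch below is not what the PR does; it is just the standard expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y built from `torch.bmm`. The function name `pairwise_l2` and the `eps` clamp are my own choices for illustration. The clamp is there because the expansion can produce tiny negative values from rounding, and because the gradient of `sqrt` is infinite at zero (which matters for the diagonal when you pass the same tensor twice).

```python
import torch


def pairwise_l2(x, y, eps=1e-12):
    """Pairwise Euclidean distances without torch.cdist (hypothetical helper).

    x: (B, M, C), y: (B, N, C) -> returns (B, M, N)
    """
    # Squared norms of every row, shaped so they broadcast to (B, M, N).
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)   # (B, M, 1)
    y_sq = y.pow(2).sum(dim=-1).unsqueeze(-2)   # (B, 1, N)
    # Cross term x . y for every pair of rows.
    cross = torch.bmm(x, y.transpose(-1, -2))   # (B, M, N)
    sq_dist = x_sq + y_sq - 2.0 * cross
    # Clamp before sqrt: removes small negatives caused by rounding and
    # keeps the backward pass finite where the distance is (near) zero.
    return sq_dist.clamp_min(eps).sqrt()


if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 2500, 256, device=device, requires_grad=True)
    dist = pairwise_l2(x, x)
    dist.sum().backward()
    print(dist.shape, x.grad.shape)
```

Keep in mind that this formulation is less accurate than a direct difference for points that are very close together, so in float32 the diagonal of `pairwise_l2(x, x)` will be small but not exactly zero.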

That would be great. I hope it can be fixed soon, since I really need this to move my work forward.