torch.nn modules don't accept tensors with a 0-sized batch dimension on GPU

torch.nn modules accept tensors with a batch dimension of 0 on CPU, but I found that they don't accept such tensors on GPU. Is this expected?
I ran the same piece of code on CPU and GPU and got {'res_cpu': (tensor([], size=(0, 1, 1)), tensor([], size=(0, 1, 1), dtype=torch.int64)), 'err_gpu': 'ERROR:CUDA error: invalid configuration argument'}. I think the CPU result and the GPU result should be consistent. Is that a reasonable expectation? 😂

Could you post a code snippet showing this error?
Some layers seem to work, e.g. as seen here:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3).cuda()
x = torch.randn(0, 3, 24, 24).cuda()  # batch size 0: no samples

out = conv(x)
print(out)
# tensor([], device='cuda:0', size=(0, 3, 22, 22),
#        grad_fn=<ConvolutionBackward0>)

and you are right that no CUDA error should be raised (even if I don't really understand the use case for tensors with an empty batch dimension :wink:).

The code snippet is as follows:

import torch

results = dict()
arg_1 = 2     # kernel_size
arg_2 = 50    # stride
arg_3 = True  # return_indices
arg_class = torch.nn.MaxPool1d(arg_1, stride=arg_2, return_indices=arg_3)
arg_4 = torch.rand([0, 1, 49], dtype=torch.float32)  # batch size 0

# Run on CPU
try:
    results["res_cpu"] = arg_class(arg_4)
except Exception as e:
    results["err_cpu"] = "ERROR:" + str(e)

# Run the same module and input on GPU
arg_class = arg_class.cuda()
arg_5 = arg_4.clone().cuda()
try:
    results["res_gpu"] = arg_class(arg_5)
except Exception as e:
    results["err_gpu"] = "ERROR:" + str(e)
print(results)

The result is: {'res_cpu': (tensor([], size=(0, 1, 1)), tensor([], size=(0, 1, 1), dtype=torch.int64)), 'err_gpu': 'ERROR:CUDA error: invalid configuration argument'}.
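The trailing size of 1 in the CPU output follows from the output-length formula given in the `MaxPool1d` documentation. Plugging in the snippet's values (kernel_size=2, stride=50, L_in=49), and assuming the defaults padding=0 and dilation=1 since the snippet does not set them:

$$
L_{out} = \left\lfloor \frac{L_{in} + 2\cdot\text{padding} - \text{dilation}\cdot(\text{kernel\_size} - 1) - 1}{\text{stride}} + 1 \right\rfloor
= \left\lfloor \frac{49 + 0 - 1\cdot 1 - 1}{50} + 1 \right\rfloor
= \lfloor 1.94 \rfloor = 1
$$

The batch dimension (0) and channel dimension (1) pass through unchanged, which gives the shape (0, 1, 1) for both the values and the indices.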
My PyTorch version is 1.8.1 and my CUDA version is 1.11.0. I'm not sure whether the latest PyTorch behaves differently w(゚Д゚)w
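To confirm which versions are actually installed, a quick check like the following may help. `torch.version.cuda` reports the CUDA toolkit version the PyTorch binary was built against (it is `None` for CPU-only builds), which can differ from the version reported by the system driver.

```python
import torch

# Installed PyTorch version, e.g. "1.8.1"
print("torch:", torch.__version__)
# CUDA toolkit version this PyTorch build was compiled with (None on CPU-only builds)
print("CUDA (build):", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # cuDNN version bundled with this build and the name of the first visible GPU
    print("cuDNN:", torch.backends.cudnn.version())
    print("GPU:", torch.cuda.get_device_name(0))
```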

The code snippet works for me and prints:

{'res_cpu': (tensor([], size=(0, 1, 1)), tensor([], size=(0, 1, 1), dtype=torch.int64)), 'res_gpu': (tensor([], device='cuda:0', size=(0, 1, 1)), tensor([], device='cuda:0', size=(0, 1, 1), dtype=torch.int64))}

so your PyTorch installation is probably too old.
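After upgrading, a small sketch like the one below could be used to check that CPU and GPU agree on empty-batch inputs. The helper names `output_shapes` and `compare_devices` are hypothetical (not PyTorch APIs), and the CUDA branch only runs when a GPU is visible. Per the output posted above, a recent install should report `[(0, 1, 1), (0, 1, 1)]` for both devices.

```python
import torch

def output_shapes(module, x):
    """Call `module` on `x` and return the shape of every output tensor."""
    out = module(x)
    # MaxPool1d(return_indices=True) returns a (values, indices) tuple.
    outs = out if isinstance(out, tuple) else (out,)
    return [tuple(t.shape) for t in outs]

def compare_devices(module, x_cpu):
    """Run the same module and input on CPU and, if available, on CUDA."""
    results = {}
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    for device in devices:
        try:
            # module.to() moves the module in place and returns it
            results[device] = output_shapes(module.to(device), x_cpu.to(device))
        except RuntimeError as e:  # e.g. "CUDA error: invalid configuration argument"
            results[device] = "ERROR: " + str(e)
    return results

pool = torch.nn.MaxPool1d(2, stride=50, return_indices=True)
x = torch.rand(0, 1, 49)  # empty batch, same shape as in the reproducer
res = compare_devices(pool, x)
print(res)
```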

Got it. I will install the latest PyTorch. Thanks :stuck_out_tongue_closed_eyes: