Torch.stack cryptic error when using out= parameter


(AI) #1

I am having a really odd issue with torch.stack. If I write something like:

import torch

a = torch.ones(3, 100, 100)
b = torch.zeros(2, 3, 100, 100)
torch.stack([a, a], dim=0, out=b)

it fails with

RuntimeError: cat(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

I can’t find any documentation related to this error. It originally occurred inside PyTorch's default_collate function, but I am able to reproduce it with the lines above.


(Simon Wang) #2

The code snippet you posted works for me.


(AI) #3

What version of PyTorch are you using? I have:

torch (0.4.0)
torchvision (0.2.1)
python (3.5.2)


(AI) #4

You are right. My example does pass, because a.grad_fn is None. Apologies for posting the wrong example. To get the error, you must first pass a through an nn.Conv2d; torch.stack then fails. However, if you use nn.functional.conv2d instead, a.grad_fn is not populated, so it does not fail.
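A minimal sketch of the failing case described above, assuming a recent PyTorch (the operator named in the message, cat() or stack(), may differ between versions). The layer sizes and tensor shapes are arbitrary choices for illustration, not the original poster's setup.

```python
import torch
import torch.nn as nn

# Plain tensor: no autograd history, so out= is accepted.
plain = torch.ones(3, 8, 8)
out = torch.empty(2, 3, 8, 8)
torch.stack([plain, plain], dim=0, out=out)
print(plain.grad_fn)  # None

# Passing a tensor through nn.Conv2d attaches a grad_fn, because the
# layer's weight and bias have requires_grad=True by default.
conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
a = conv(torch.ones(1, 3, 8, 8))
print(a.grad_fn is not None, a.requires_grad)  # True True

out = torch.empty(2, 1, 3, 8, 8)
try:
    torch.stack([a, a], dim=0, out=out)
    failed = False
except RuntimeError as err:
    # Message along the lines of "...functions with out=... arguments
    # don't support automatic differentiation..."
    print(err)
    failed = True
```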


(Alban D) #5

Hi,

I think the problem is simply that you cannot use the out= keyword argument when autograd is involved.
In the example above, nothing requires grad, so autograd is not used and the call works.
I guess that with the functional version of conv you also don’t require grads, so it works as well.
With the nn version, the conv parameters require grad by default, so the call fails.
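Two common ways around this are sketched below, assuming you either do not need gradients through the stacked result (for example when collating data) or do not need the preallocated buffer: run the out= call under torch.no_grad(), or drop out= and let torch.stack allocate the result so autograd can track it. The shapes and the name buf are illustrative.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
a = conv(torch.ones(1, 3, 8, 8))  # requires grad through conv's parameters

# Option 1: no gradients needed, so disable autograd for the out= call.
buf = torch.empty(2, 1, 3, 8, 8)
with torch.no_grad():
    torch.stack([a, a], dim=0, out=buf)
print(buf.requires_grad)  # False: buf holds values only, no graph

# Option 2: gradients needed, so skip out= and use the returned tensor.
stacked = torch.stack([a, a], dim=0)
stacked.sum().backward()
print(conv.weight.grad is not None)  # True: gradients reached the conv weights
```

Calling .detach() on the inputs before a single out= call has the same effect as Option 1, since the detached tensors no longer require grad.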


(Soufiane Belharbi) #6

Thanks!
I ran into the same problem when using torch.sort(..., out=(sorted_c, indices)); I get the error RuntimeError: sort(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

The documentation of torch.sort() does not mention that out= can’t be used when one of the arguments requires grad. It would probably be worth adding this information to the documentation (if it isn’t already mentioned somewhere).
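The same pattern applies to torch.sort. A minimal sketch, assuming a small 1-D input that requires grad; the buffer names values_buf and indices_buf are illustrative:

```python
import torch

x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
values_buf = torch.empty(3)
indices_buf = torch.empty(3, dtype=torch.long)

# Fails: x requires grad and autograd is enabled.
try:
    torch.sort(x, out=(values_buf, indices_buf))
    sort_failed = False
except RuntimeError:
    sort_failed = True

# Works: let torch.sort allocate its outputs, so the values stay differentiable.
values, indices = torch.sort(x)
values.sum().backward()
print(indices.tolist())  # [1, 2, 0]
print(x.grad.tolist())   # [1.0, 1.0, 1.0]
```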

I use Python 3.7.1 and PyTorch 1.0.0.

Thanks!


(Alban D) #7

Hi,

This is not specific to sort(). Autograd does not support out= for any function.
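To illustrate that the restriction is not tied to one operator, the sketch below tries a few unrelated functions with out= on an input that requires grad (the choice of torch.add, torch.mul and torch.cat is arbitrary), then repeats each call under torch.no_grad(), where it succeeds.

```python
import torch

x = torch.ones(4, requires_grad=True)

calls = {
    "add": lambda out: torch.add(x, 1, out=out),
    "mul": lambda out: torch.mul(x, 2, out=out),
    "cat": lambda out: torch.cat([x, x], out=out),
}
sizes = {"add": 4, "mul": 4, "cat": 8}

results = {}
for name, call in calls.items():
    out = torch.empty(sizes[name])
    try:
        call(out)  # autograd enabled and x requires grad: expected to raise
        results[name] = "ok"
    except RuntimeError:
        results[name] = "RuntimeError"
    with torch.no_grad():
        call(out)  # same call with autograd disabled succeeds
print(results)  # {'add': 'RuntimeError', 'mul': 'RuntimeError', 'cat': 'RuntimeError'}
```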