Doing backward with torch.autograd.grad causes "cuDNN requires contiguous weight tensor"?

I use torch.autograd.grad to compute an extra loss based on the gradient (like the gradient penalty in WGAN-GP):

# Gradient of the output w.r.t. the input, kept in the graph
# (create_graph=True) so the penalty can itself be backpropagated.
gradient = torch.autograd.grad(
    output_value, input_data,
    create_graph=True, retain_graph=True, only_inputs=True,
)[0].contiguous()

# Flatten per sample and penalize the L2 norm of the gradient.
B = gradient.size(0)
gradient = gradient.view(B, -1)
grad_norm = gradient.norm(2, dim=1)
ge_loss = -grad_norm.mean()
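
For context, a rough sketch of how this penalty enters my training step (model, sup_loss, and lambda_gp are placeholder names, not the ones in my actual script; I also assume output_value is a scalar, so autograd.grad needs no grad_outputs):

input_data.requires_grad = True                 # input must require grad
output_value = model(input_data).sum()          # placeholder scalar output
# ... compute gradient / grad_norm / ge_loss as above ...
total_loss = sup_loss + lambda_gp * ge_loss
total_loss.backward()                           # this is the call that fails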

Then calling backward with this gradient loss raises:

Traceback (most recent call last):
  File "train_stage3_sup_only.py", line 223, in <module>
    main()
  File "train_stage3_sup_only.py", line 124, in main
    total_loss.backward()
  File "/mnt/lustre/zhenghuabin/anaconda3/envs/py35/lib/python3.5/site-packages/torch/autograd/variable.py", line 167, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/mnt/lustre/zhenghuabin/anaconda3/envs/py35/lib/python3.5/site-packages/torch/autograd/__init__.py", line 99, in backward
    variables, grad_variables, retain_graph)
RuntimeError: cuDNN requires contiguous weight tensor

I suppose the error arises because some variable in the computational graph is not contiguous, so I try to check contiguousness with the functions below:

def backtrace_graph(variable):
    # Report any non-contiguous variable reachable from this one.
    if not variable.is_contiguous():
        print('%s,' % variable.is_contiguous(), variable.size())
    if variable.grad_fn is None:
        return
    for nf, _ in variable.grad_fn.next_functions:
        backtrace_nf(nf)

def backtrace_nf(nf):
    # Recurse through grad_fn nodes; leaf nodes expose their variable.
    if nf is None:
        return
    if hasattr(nf, 'variable'):
        backtrace_graph(nf.variable)
    else:
        for nfnf, _ in nf.next_functions:
            backtrace_nf(nfnf)

However, calling backtrace_graph(my_gradient_loss) doesn't print anything, i.e. it finds no non-contiguous variable in the graph.
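
As a complementary check (the error message mentions a weight tensor), one could also look at the network parameters directly; a quick sketch, where model stands for whatever network contains the conv/ConvTranspose layers:

# Print any parameter (e.g. ConvTranspose weights) that is not contiguous.
for name, param in model.named_parameters():
    if not param.data.is_contiguous():
        print(name, param.size())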

Any suggestions to fix it?

It seems the error arises from the combination of ConvTranspose + autograd.grad + backward. See https://github.com/pytorch/pytorch/issues/5044 for a minimal test script.
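
Until that issue is fixed, a workaround that should sidestep the cuDNN double-backward path is to disable cuDNN entirely (a hedged sketch, at the cost of some speed; it has to be set before the forward pass, since the cuDNN backward functions are recorded into the graph during the forward):

import torch.backends.cudnn as cudnn

# Fall back to the native (non-cuDNN) convolution kernels so that
# backward through autograd.grad's graph never hits the cuDNN path.
cudnn.enabled = False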
