Unclear CUDNN RuntimeError during backward

I’m trying to implement the Model-Agnostic Meta-Learning (MAML) algorithm in PyTorch, based on the original TensorFlow implementation. To summarize the algorithm:

  1. Compute the gradients of a model parametrized by Theta with respect to the loss on some training samples.
  2. Compute Theta_ = Theta - lr*Theta_grad.
  3. Compute the loss of the model parametrized by Theta_ on some test samples.
  4. Compute the gradients of this test loss all the way back to the original Theta (written out as an equation below).
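
In equation form: the inner update is Theta_ = Theta - lr * dL_train(Theta)/dTheta, and the meta-gradient in step 4 is dL_test(Theta_)/dTheta, i.e. the test loss is differentiated through the inner update back to the original Theta.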

Here is the code I wrote:

# Step 1: loss on the training samples under the original weights Theta
train_output = forward(train_images, weights)
train_loss = F.cross_entropy(train_output, train_labels)

# create_graph=True so that the inner SGD step itself remains differentiable
grads = th.autograd.grad(train_loss, weights.values(), create_graph=True)

# Step 2: fast weights Theta_ = Theta - update_lr * grad, kept inside the graph
gradients = dict(zip(weights.keys(), grads))
fast_weights = Munch(dict(zip(weights.keys(),
                              [weights[key] - args.update_lr * gradients[key]
                               for key in weights.keys()])))

# Step 3: loss on the test samples under the fast weights Theta_
test_output = forward(test_images, fast_weights)
test_loss = F.cross_entropy(test_output, test_labels)

# Sanity check: gradient of the test loss w.r.t. one fast weight,
# then chained back to the corresponding original weight
temp_grad = th.autograd.grad(test_loss, fast_weights.b5, retain_graph=True)
new_grads = th.autograd.grad(fast_weights.b5, weights.b5, grad_outputs=temp_grad)

# Step 4: backprop the test loss all the way back to Theta
test_loss.backward()

This code works fine when executed on the CPU but fails on the GPU. I am using PyTorch 0.2.0. Complete error message:

Traceback (most recent call last):
  File "main_f.py", line 260, in <module>
    test_loss.backward()
  File "/home/lib/python3.5/site-packages/torch/autograd/variable.py", line 156, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/home/lib/python3.5/site-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
RuntimeError: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.

Why does this happen?
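
For context, after test_loss.backward() the outer (meta) update is just a plain SGD step on the original weights, along the lines of the sketch below (args.meta_lr is my outer-loop learning rate; this part is never reached on the GPU because the error is raised before it):

for key in weights.keys():
    # apply the meta-gradient accumulated by test_loss.backward()
    weights[key].data -= args.meta_lr * weights[key].grad.data
    weights[key].grad.data.zero_()  # reset accumulated grads for the next meta-batch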

Related Code:

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from munch import Munch

# The raw weight/bias Variables the meta-update operates on (Theta)
weights = Munch()

weights.W1, weights.b1 = init_conv(1, args.num_filters, (3, 3))
weights.W2, weights.b2 = init_conv(args.num_filters, args.num_filters, (3, 3))
weights.W3, weights.b3 = init_conv(args.num_filters, args.num_filters, (3, 3))
weights.W4, weights.b4 = init_conv(args.num_filters, args.num_filters, (3, 3))
weights.W5, weights.b5 = init_fc(args.num_filters, args.num_classes)

# Batch-norm layers are ordinary modules and are not part of the meta-update
bns = Munch()

bns.bn1 = nn.BatchNorm2d(args.num_filters)
bns.bn2 = nn.BatchNorm2d(args.num_filters)
bns.bn3 = nn.BatchNorm2d(args.num_filters)
bns.bn4 = nn.BatchNorm2d(args.num_filters)

if args.cuda:
    weights = Munch({k: w.cuda() for k, w in weights.items()})
    bns = Munch({k: bn.cuda() for k, bn in bns.items()})

def conv_block(input, weight, bias, bn):
    # 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x2 max pool
    out = F.conv2d(input, weight, bias, padding=1)
    out = bn(out)
    out = F.relu(out)
    return F.max_pool2d(out, 2)

def forward(input, weights):
    # Four conv blocks followed by a linear classifier, all driven by the
    # explicitly passed weights so the same function works for Theta and Theta_
    out = conv_block(input, weights.W1, weights.b1, bns.bn1)
    out = conv_block(out, weights.W2, weights.b2, bns.bn2)
    out = conv_block(out, weights.W3, weights.b3, bns.bn3)
    out = conv_block(out, weights.W4, weights.b4, bns.bn4)
    out = out.view(-1, 64)  # flatten to (batch, 64) for the final linear layer
    return F.linear(out, weights.W5, weights.b5)
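
init_conv and init_fc are not shown above; they just create leaf Variables with requires_grad=True, roughly along these lines (the exact initialization scheme shouldn't matter for the error):

from torch.autograd import Variable

def init_conv(in_channels, out_channels, kernel_size):
    # conv weight layout expected by F.conv2d: (out_channels, in_channels, kH, kW)
    W = Variable(th.randn(out_channels, in_channels, *kernel_size) * 0.01, requires_grad=True)
    b = Variable(th.zeros(out_channels), requires_grad=True)
    return W, b

def init_fc(in_features, out_features):
    # weight layout expected by F.linear: (out_features, in_features)
    W = Variable(th.randn(out_features, in_features) * 0.01, requires_grad=True)
    b = Variable(th.zeros(out_features), requires_grad=True)
    return W, b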