Encountering a RuntimeError: CUDNN_STATUS_NOT_SUPPORTED when computing gradients through gradients

I’m trying to implement the paper “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks”, which involves computing gradients through gradients, along these lines:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

net = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1)
x1 = Variable(torch.FloatTensor(1, 3, 256, 256).normal_())
x2 = Variable(torch.FloatTensor(1, 3, 256, 256).normal_())
target = Variable(torch.FloatTensor(1, 1, 256, 256).fill_(1))
criterion = nn.BCEWithLogitsLoss()

# Inner step: compute gradients and keep the graph
# so we can differentiate through them later.
y = net(x1)
loss = criterion(y, target)
net.zero_grad()
loss.backward(create_graph=True)

# Outer step: build "fast weights" from the gradients
# and backpropagate through the whole update.
alpha = 1e-4
w = net.weight - alpha * net.weight.grad
b = net.bias - alpha * net.bias.grad
z = F.conv2d(x2, w, b, stride=1, padding=1)
loss = criterion(z, target)
net.zero_grad()
loss.backward()

The code above runs fine on the CPU. However, when I move everything to the GPU:

# Same computation as above, but with everything on the GPU.
net = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1)
net = net.cuda()
x1 = Variable(torch.FloatTensor(1, 3, 256, 256).normal_()).cuda()
x2 = Variable(torch.FloatTensor(1, 3, 256, 256).normal_()).cuda()
target = Variable(torch.FloatTensor(1, 1, 256, 256).fill_(1)).cuda()
criterion = nn.BCEWithLogitsLoss()

y = net(x1)
loss = criterion(y, target)
net.zero_grad()
loss.backward(create_graph=True)

alpha = 1e-4
w = net.weight - alpha * net.weight.grad
b = net.bias - alpha * net.bias.grad
z = F.conv2d(x2, w, b, stride=1, padding=1)
loss = criterion(z, target)
net.zero_grad()
loss.backward()

it raises an error:

RuntimeError: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.

This is confusing, because as far as I can tell there is no non-contiguous input anywhere. Does anyone know why this happens? Is there a solution?
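
For what it’s worth, here is the quick sanity check I used to convince myself that everything is contiguous (just is_contiguous() calls on every tensor involved, run right before the second backward):

# Sanity check: contiguity of everything involved in the second conv.
# All of these print True on my machine.
for name, t in [('x1', x1), ('x2', x2), ('target', target),
                ('weight', net.weight), ('weight.grad', net.weight.grad),
                ('w', w), ('b', b)]:
    print(name, t.is_contiguous())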

Thank you in advance.

Update:

The error disappears when I use inputs of a different size. For example, the code below (one input channel instead of three) runs fine:

net = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1)
net = net.cuda()
x1 = Variable(torch.FloatTensor(1, 1, 256, 256).normal_()).cuda()
x2 = Variable(torch.FloatTensor(1, 1, 256, 256).normal_()).cuda()
target = Variable(torch.FloatTensor(1, 1, 256, 256).fill_(1)).cuda()
criterion = nn.BCEWithLogitsLoss()  # unused in this variant

# Note: here backward() is driven directly by a gradient tensor
# rather than a scalar loss.
y = net(x1)
net.zero_grad()
y.backward(target, create_graph=True)

alpha = 1e-4
w = net.weight - alpha * net.weight.grad
b = net.bias - alpha * net.bias.grad
z = F.conv2d(x2, w, b, stride=1, padding=1)
net.zero_grad()
z.backward(target)
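
For the original configuration, the only workaround I can think of is to disable cuDNN so that the convolutions fall back to the native implementation. This is just a sketch based on my guess that cuDNN’s double backward is the culprit for this shape; I haven’t measured how much slower it is:

# Workaround sketch (assumes cuDNN's double backward is the problem):
# disable cuDNN globally so convolutions use the native implementation.
import torch.backends.cudnn as cudnn

cudnn.enabled = False   # affects all subsequent conv calls
# ... run the GPU code from above ...
cudnn.enabled = True    # restore cuDNN afterwards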