Autograd.grad dimension error

I am currently implementing the RELAX gradient estimator for a stochastic function, which requires computing the gradient of the variance of the gradient. After updating from 0.3 to master, I got the following error message:

Expected 4-dimensional input for 4-dimensional weight [10], but got input of size [11, 1, 32, 32] instead

In this particular line:

grad_phi = torch.autograd.grad([val], phi)[:len(phi)]

Here's my code for the forward/backward pass where the error appears:

input, target = Variable(batch_x), Variable(batch_y).type(torch.LongTensor)
mean, log_std = net(input)
        
dist = Normal(mean=mean, std=log_std.exp())
rsample = dist.rsample()
sample = rsample.detach()

f = hamming_loss(sample, target)
c = reduce_net(criterion(rsample, target))
log_prob = dist.log_prob(sample).sum()

first_term = torch.autograd.grad([log_prob], theta, create_graph=True)[:len(theta)]
second_term = list(torch.autograd.grad([c], theta, create_graph=True)[:len(theta)])

val = 0.
for grad_id in range(len(second_term)):
    theta[grad_id].grad = first_term[grad_id] * (f - c) + second_term[grad_id]
    val += (theta[grad_id].grad ** 2).sum()
grad_phi = torch.autograd.grad([val], phi)[:len(phi)]
for j, param in enumerate(phi): param.grad = grad_phi[j]
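For reference, here is a sketch (in my own notation, not taken from the RELAX paper) of what the snippet above computes. It assumes that b is the detached sample (`sample`), z the reparameterized draw (`rsample`, the same draw as b), f the Hamming loss, c_φ the output of `reduce_net`, and k an index over every entry of every tensor in `theta`. If the estimator is unbiased, its mean does not depend on φ. The gradient of its variance then reduces to the gradient of its second moment, which is why accumulating squared gradients into `val` is enough. A single sample gives a Monte Carlo estimate of that expectation.

```latex
\hat{g} = \nabla_\theta \log p(b \mid \theta)\,\bigl[f(b) - c_\phi(z)\bigr] + \nabla_\theta c_\phi(z),
\qquad b = \operatorname{stopgrad}(z)

\frac{\partial}{\partial \phi} \sum_k \operatorname{Var}\bigl[\hat{g}_k\bigr]
  = \frac{\partial}{\partial \phi} \sum_k \Bigl( \mathbb{E}\bigl[\hat{g}_k^2\bigr] - \mathbb{E}\bigl[\hat{g}_k\bigr]^2 \Bigr)
  = \frac{\partial}{\partial \phi}\, \mathbb{E}\Bigl[ \sum_k \hat{g}_k^2 \Bigr]
  \quad \text{if } \mathbb{E}[\hat{g}_k] \text{ does not depend on } \phi

\texttt{val} = \sum_k \hat{g}_k^2, \qquad \texttt{grad\_phi} = \nabla_\phi\, \texttt{val}
```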

Here is the code for the modules and the Hamming loss:

class NormalParametersNet(nn.Module):
    def __init__(self):
        super(NormalParametersNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)

        self.fc1_mean = nn.Linear(5 * 5 * 20, 140)
        self.fc2_mean = nn.Linear(140, num_classes)

        self.fc1_log_std = nn.Linear(5 * 5 * 20, 140)
        self.fc2_log_std = nn.Linear(140, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 5 * 5 * 20)
        
        mean = F.relu(self.fc1_mean(x))
        mean = self.fc2_mean(mean)

        log_std = F.relu(self.fc1_log_std(x))
        log_std = self.fc2_log_std(log_std)

        return mean, log_std


class ApproxNet(nn.Module):
    def __init__(self, num_neurons=20):
        super(ApproxNet, self).__init__()
        self.num = num_neurons
        self.linear1 = nn.Linear(1, self.num)
        self.linear2 = nn.Linear(self.num, self.num)
        self.linear3 = nn.Linear(self.num, 1)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))

        return self.linear3(x)


net = NormalParametersNet()
reduce_net = ApproxNet()


def hamming_loss(output, target):
    # cast to float so the division is not integer division on newer versions
    return 1 - (output.data.max(1)[1] == target.data).float().mean()

Normal is the class from torch.distributions. The input has shape 11 x 1 x 32 x 32. Can somebody point out the mistake?
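As a sanity check on the architecture, the `5 * 5 * 20` flattening in `forward` does match a 32 x 32 input, so the input shape itself is not what is wrong. The arithmetic below is a sketch that uses the standard output-size formula for a convolution and for non-overlapping 2 x 2 max-pooling. The helper names are hypothetical and exist only for illustration.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2-D convolution along one axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def pool2d_out(size, kernel):
    """Output size of max-pooling with stride == kernel and no padding."""
    return (size - kernel) // kernel + 1


def flattened_features(size=32):
    # conv1: 5x5 kernel, 10 channels, followed by 2x2 max-pool
    size = pool2d_out(conv2d_out(size, 5), 2)
    # conv2: 5x5 kernel, 20 channels, followed by 2x2 max-pool
    size = pool2d_out(conv2d_out(size, 5), 2)
    return size * size * 20


print(flattened_features(32))  # 500 == 5 * 5 * 20
```

The spatial size goes 32 → 28 → 14 → 10 → 5. A 28 x 28 input (e.g. MNIST) would instead give 4 * 4 * 20 = 320 features.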

Forgot to mention:

criterion = nn.CrossEntropyLoss()

theta = list(net.parameters())
phi = list(reduce_net.parameters())

If I change the input in the following way:

input = Variable(batch_x, requires_grad=True)

it fails in .zero_grad() with "Can't detach views in-place. Use detach() instead"; however, if I comment out zero_grad(), everything works fine.

The error message isn't very clear, and I apologize for that since I wrote it. I have improved it in a recent PR.

That said, it is weird that conv got a 1-d weight tensor with your code. If you don't mind, could you post or send me a runnable script with which I can reproduce and debug the issue? You can replace the input and target Variables with randomly generated data if needed. Thanks!

Sorry for the late response; I've made a script that reproduces this error.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from torch.autograd import Variable
from torch.autograd import grad


torch.manual_seed(42)
num_classes = 26

class UnaryNet(nn.Module):
    def __init__(self):
        super(UnaryNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)

        self.fc1_mean = nn.Linear(5 * 5 * 20, 140)
        self.fc2_mean = nn.Linear(140, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 5 * 5 * 20)
        
        mean = F.relu(self.fc1_mean(x))
        mean = self.fc2_mean(mean)

        return mean

net = UnaryNet()

phi = Variable(torch.Tensor([1.]), requires_grad=True)
theta = net.parameters()

input, target = Variable(torch.randn(5,1,32,32)), Variable(torch.rand(5).long())
criterion = nn.CrossEntropyLoss()

loss1 = criterion(net(input), target) * phi
theta_grad = torch.autograd.grad([loss1], theta, create_graph=True)

loss2 = 0.
for grad in theta_grad: loss2 += grad.pow(2).sum()
phi_grad = torch.autograd.grad([loss2], [phi])
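The repro boils down to a double-backward pattern whose result can be checked by hand. Because loss1 = phi * L, every entry of `theta_grad` is phi times the plain gradient of L. That makes loss2 = phi² · ||∇L||², so d loss2 / d phi = 2 · phi · ||∇L||². The sketch below assumes current PyTorch (plain tensors instead of `Variable`) and uses a hypothetical `nn.Linear` stand-in for `UnaryNet`, so the conv double-backward path is not involved at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for UnaryNet: one linear layer, CPU only.
net = nn.Linear(4, 3)
theta = list(net.parameters())
phi = torch.tensor([1.5], requires_grad=True)

x = torch.randn(5, 4)
target = torch.randint(0, 3, (5,))
criterion = nn.CrossEntropyLoss()

# loss1 = phi * L(theta), so each d loss1 / d theta = phi * dL/dtheta
loss1 = criterion(net(x), target) * phi
theta_grad = torch.autograd.grad(loss1, theta, create_graph=True)

# loss2 = sum of squared gradient entries = phi^2 * ||dL/dtheta||^2
loss2 = sum(g.pow(2).sum() for g in theta_grad)
phi_grad, = torch.autograd.grad(loss2, phi)

# Analytic value: d loss2 / d phi = 2 * phi * ||dL/dtheta||^2
plain_grad = torch.autograd.grad(criterion(net(x), target), theta)
expected = 2 * phi.detach() * sum(g.pow(2).sum() for g in plain_grad)
print(phi_grad, expected)
```

With a conv layer in place of `nn.Linear`, the same check should pass once the double-backward fix is in your build.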

@apaszke Could this be related to the autograd.grad bug?

Which bug do you have in mind? It looks more like an incorrect backward definition for convs.

I don't know the details, but I just saw Priya mention that you are working on an autograd.grad bug, so I asked in case this is the same issue.

Likely. I just thought the conv backward functions were well tested, so such standard things shouldn't fail. I'll gdb it then.

Are there any updates?

I'm tracking it down. Sorry, I was working on some other tasks before. Here is an MWE we found. I'll let you know if there is any update.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from torch.autograd import Variable
from torch.autograd import grad


conv = nn.Conv2d(1, 10, 5).cuda()

input = Variable(torch.randn(1,1,32,32).cuda())

loss1 = conv(input).sum()
t, = torch.autograd.grad(loss1, conv.bias, create_graph=True)

loss2 = t.sum()
loss2.backward()

Will be fixed once this PR merges: https://github.com/pytorch/pytorch/pull/4812.

Hi, I ran into a similar error while playing around with the Transfer Learning Tutorial (https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html). I'm posting this because the error message doesn't seem related to the actual error; at least, the values listed in it don't make sense to me.
The error itself was caused by passing a tensor of the wrong shape to the model.
I found a workaround for this particular case, but reporting it may still be helpful. Below is what I did and the error I got. I'm using PyTorch v0.4.1.

  1. After training the model, I'd like to feed it a single image, so I define a single image path and a custom transform that converts the image to its tensor equivalent.
  2. When I feed this tensor ([3, 224, 224]) to the model, the batch dimension is missing, which gives me the error (bottom of the post).
  3. As a workaround, I 'unsqueeze' the tensor to the correct shape ([1, 3, 224, 224]).
  4. Feeding the unsqueezed tensor to the model gives the expected prediction.
from PIL import Image
from torchvision import transforms

# Get single image
img_path = "/hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path).convert('RGB')

# Define transformation
composed = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# model_ft is the fine-tuned model from the tutorial
img_trfd = composed(img)
img_trfd.size()              # output: torch.Size([3, 224, 224])
model_ft(img_trfd)           # GIVES ERROR AS BELOW

img_trfd_unsq = img_trfd.unsqueeze(0)
img_trfd_unsq.size()         # output: torch.Size([1, 3, 224, 224])
model_ft(img_trfd_unsq)      # Predicting results correctly

Error with the incorrect input tensor size:

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/home/bloks/Projects/Sentriq/sentriq_venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/bloks/Projects/Sentriq/sentriq_venv/lib/python3.5/site-packages/torchvision/models/resnet.py", line 139, in forward
    x = self.conv1(x)
  File "/home/bloks/Projects/Sentriq/sentriq_venv/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/bloks/Projects/Sentriq/sentriq_venv/lib/python3.5/site-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got input of size [3, 224, 224] instead

The error message makes perfect sense: you are missing the batch dimension. Call .unsqueeze_(0) on the input before passing it to the model.
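A minimal sketch of the fix, assuming current PyTorch. It uses only a standalone layer shaped like the one in the traceback (weight [64, 3, 7, 7]); the stride/padding values are those of torchvision's ResNet `conv1` and are an assumption here, not something stated in the post. `unsqueeze(0)` returns a new tensor with a leading batch dimension of size 1, while `unsqueeze_(0)` does the same in place.

```python
import torch
import torch.nn as nn

# Layer with the same weight shape as in the traceback: [64, 3, 7, 7]
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

img = torch.randn(3, 224, 224)   # one image as produced by ToTensor(): [C, H, W]
batch = img.unsqueeze(0)         # add the batch dimension: [1, C, H, W]

with torch.no_grad():
    out = conv1(batch)

print(batch.shape)  # torch.Size([1, 3, 224, 224])
print(out.shape)    # torch.Size([1, 64, 112, 112])
```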

Thank you, that helped me. But now I get this error:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 11, 11], but got 5-dimensional input of size [20, 16, 3, 224, 224] instead

If I try calling unsqueeze(0), I get a tensor of size [1, 20, 16, 3, 224, 224]. How can I reduce it to 4 dimensions?
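One common approach assumes the second dimension (16) indexes frames or crops that should each pass through the 2-D network independently. Under that assumption, fold it into the batch dimension instead of adding another one, because `unsqueeze(0)` only adds a dimension. The sketch below uses a hypothetical AlexNet-style first layer matching the weight [64, 3, 11, 11] (the stride/padding are assumptions) and a smaller tensor to keep memory low. If the 16 frames should instead be convolved jointly, `nn.Conv3d` is the layer meant for 5-D input, laid out as [N, C, D, H, W] after a permute.

```python
import torch
import torch.nn as nn

# Hypothetical first layer with the weight shape from the error: [64, 3, 11, 11]
conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)

# Small stand-in for the [20, 16, 3, 224, 224] input in the question.
clips = torch.randn(2, 4, 3, 224, 224)            # [B, T, C, H, W]
B, T = clips.shape[:2]

# Fold the frame dimension into the batch: [B*T, C, H, W] is 4-D.
frames = clips.reshape(B * T, *clips.shape[2:])
with torch.no_grad():
    feats = conv1(frames)                          # [B*T, 64, 55, 55]

# Restore the per-clip grouping if later layers need it.
feats = feats.reshape(B, T, *feats.shape[1:])      # [B, T, 64, 55, 55]
print(frames.shape, feats.shape)
```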