VGG | RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I am training a VGG16 on 224x224 images with 4 classes.

import torch
from torch import nn, optim
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = models.vgg16(pretrained=True)

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Sequential(nn.Linear(2048, 512),
                         nn.ReLU(),
                         nn.Linear(512, 4),
                         nn.LogSoftmax(dim=1))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.0001)
model.to(device)

And I am training it for 50 epochs:

epochs = 50
steps = 0
running_loss = 0
print_every = 105
train_losses, test_losses, accuracy = [], [], []

for epoch in range(epochs):
    for inputs, labels in trainloader:
        #print('working')
        steps += 1
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        logps = model(inputs)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
        if steps % print_every == 0:
            test_loss = 0
            accuracy = 0
            model.eval()
            with torch.no_grad():
                for inputs, labels in testloader:
                    #print('validating')
                    inputs, labels = inputs.to(device), labels.to(device)
                    logps = model(inputs)
                    batch_loss = criterion(logps, labels)
                    test_loss += batch_loss.item()
                    
                    ps = torch.exp(logps)
                    top_p, top_class = ps.topk(1, dim=1)
                    equals = top_class == labels.view(*top_class.shape)
                    accuracy += torch.mean(equals.float()).item()

            train_losses.append(running_loss/print_every)
            test_losses.append(test_loss/len(testloader))
            print(f"Epoch {epoch+1}/{epochs}.. "
                  f"Train loss: {running_loss/print_every:.3f}.. "
                  f"Test loss: {test_loss/len(testloader):.3f}.. "
                  f"Test accuracy: {accuracy/len(testloader):.3f}")
            running_loss = 0
            model.train()

However, this error message pops up:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

What am I doing wrong? Thanks!

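vgg16 does not have a model.fc attribute; it uses model.classifier instead.
Because of that you are currently freezing the complete model and assigning the nn.Sequential container to a new, unused attribute called fc. Since forward() never calls model.fc, none of the parameters used to compute the output require gradients, so loss.backward() raises this error.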
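To make this concrete, here is a small sketch that reproduces the error and the fix without downloading any weights. TinyVGG is a hypothetical stand-in that only mimics VGG's layout (.features and .classifier, no .fc); the layer sizes and the try_backward helper are made up for illustration.

```python
import torch
from torch import nn


class TinyVGG(nn.Module):
    """Hypothetical stand-in for torchvision's vgg16: it has .features and
    .classifier submodules, but no .fc attribute."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((2, 2)),
        )
        self.classifier = nn.Linear(8 * 2 * 2, 1000)

    def forward(self, x):
        # Like VGG, forward() only ever calls .features and .classifier.
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


def try_backward(model):
    """Run one forward/backward pass; return the error message, or None on success."""
    inputs = torch.randn(2, 3, 16, 16)
    labels = torch.tensor([0, 3])
    loss = nn.CrossEntropyLoss()(model(inputs), labels)
    try:
        loss.backward()
    except RuntimeError as err:
        return str(err)
    return None


# Reproduce the problem: freeze everything, then attach an unused .fc head.
model = TinyVGG()
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 4))
print(try_backward(model))  # prints the "does not require grad" RuntimeError message

# Fix: replace the head that forward() actually uses.
model.classifier = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 4))
print(try_backward(model))  # None
```

The newly created nn.Linear layers require gradients by default, so only the replaced classifier is trained while the frozen feature extractor stays fixed.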


So should it be like this instead?

for param in model.parameters():
    param.requires_grad = False
    
model.classifier = nn.Sequential(nn.Linear(2048, 512),
                                 nn.ReLU(),
                                 #nn.Dropout(0.2),
                                 nn.Linear(512, 4),
                                 nn.LogSoftmax(dim=1))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=0.0001)
model.to(device)

The first linear layer of the vgg16 classifier expects 25088 input features, but apart from that the code looks alright.

EDIT:
Sorry, missed an issue.
Since you are using nn.LogSoftmax as the last activation function, you should use nn.NLLLoss as your criterion.

Alternatively you could remove the nn.LogSoftmax and use nn.CrossEntropyLoss, which will internally use F.log_softmax and nn.NLLLoss.
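A quick numerical check of the equivalence, using random logits and labels as made-up example data: nn.LogSoftmax followed by nn.NLLLoss gives the same loss as nn.CrossEntropyLoss applied to the raw outputs.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(5, 4)              # raw outputs of a final nn.Linear(..., 4)
labels = torch.tensor([0, 1, 2, 3, 1])  # example class indices for 4 classes

# Option 1: keep nn.LogSoftmax in the model and use nn.NLLLoss.
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)

# Option 2: drop nn.LogSoftmax and let nn.CrossEntropyLoss apply it internally.
ce = nn.CrossEntropyLoss()(logits, labels)

print(torch.allclose(nll, ce))  # True
```

With option 2, remember that the model then outputs logits, so torch.exp(logps) in the validation loop would need to become torch.softmax(logits, dim=1) if probabilities are needed (the top-1 class is the same either way).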

Does this make any sense at all given my 224x224 images and 4 classes?

for param in model.parameters():
    param.requires_grad = True
    
model.classifier = nn.Sequential(nn.Linear(25088, 64),
                                 nn.ReLU(),
                                 #nn.Dropout(0.2),
                                 nn.Linear(64, 4),
                                 nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=0.0001)
model.to(device)
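Note that this snippet sets requires_grad = True for every parameter while passing only model.classifier.parameters() to the optimizer, so autograd still computes gradients for the convolutional layers that are never applied. A small sketch for checking what actually trains, using a hypothetical TinyVGG stand-in (made-up layer sizes) and a hypothetical count helper:

```python
from torch import nn, optim


class TinyVGG(nn.Module):
    """Hypothetical stand-in exposing VGG-style .features and .classifier."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.classifier = nn.Linear(8, 4)


def count(params):
    """Total number of scalar values in an iterable of parameters."""
    return sum(p.numel() for p in params)


model = TinyVGG()

# As in the post: everything requires grad, but only the head is optimized.
for param in model.parameters():
    param.requires_grad = True
optimizer = optim.Adam(model.classifier.parameters(), lr=0.0001)
trainable = count(p for p in model.parameters() if p.requires_grad)
optimized = count(p for group in optimizer.param_groups for p in group["params"])
print(trainable, optimized)  # 260 36: backbone grads are computed but unused

# Freezing only the feature extractor makes the two numbers match.
for param in model.features.parameters():
    param.requires_grad = False
trainable = count(p for p in model.parameters() if p.requires_grad)
print(trainable == optimized)  # True
```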

I am trying to replicate the last layer of this ResNet50 model:

for param in model.parameters():
    param.requires_grad = False
    
model.fc = nn.Sequential(nn.Linear(2048, 512),
                                 nn.ReLU(),
                                 nn.Linear(512, 4),
                                 nn.LogSoftmax(dim=1))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.0001)
model.to(device)

I am kind of lost here…

The number of input features of the first linear layer in a “standard” CNN is defined by the flattened activation from the feature extractor (conv-pool layers).
Each input image starts with 3*224*224 = 150528 values. The conv and pooling layers then change the activation shape (spatially and channel-wise), so that the output of the feature extractor has the shape [batch_size, 512, 7, 7]. After flattening this activation, you’ll end up with a tensor of shape [batch_size, 25088], since 512*7*7 = 25088.
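The shape arithmetic can be checked with a slimmed-down, VGG16-like feature extractor. This is an assumption-laden sketch: torchvision's vgg16 uses two or three convs per block, while this hypothetical make_features uses one, but the output shape only depends on the channel counts (64, 128, 256, 512, 512) and the five 2x2 max-pool layers, which halve 224 down to 224 / 2**5 = 7. (torchvision's VGG additionally applies nn.AdaptiveAvgPool2d((7, 7)) before flattening, which leaves a 7x7 map unchanged.)

```python
import torch
from torch import nn


def make_features():
    """Hypothetical VGG16-like stack: one 3x3 conv + ReLU + 2x2 max-pool per block."""
    layers, in_ch = [], 3
    for out_ch in (64, 128, 256, 512, 512):
        layers += [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),  # keeps H and W
                   nn.ReLU(inplace=True),
                   nn.MaxPool2d(kernel_size=2, stride=2)]               # halves H and W
        in_ch = out_ch
    return nn.Sequential(*layers)


features = make_features()
x = torch.randn(1, 3, 224, 224)  # 3*224*224 = 150528 values per image
with torch.no_grad():
    out = features(x)
print(out.shape)   # torch.Size([1, 512, 7, 7])
flat = torch.flatten(out, 1)
print(flat.shape)  # torch.Size([1, 25088])
```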

Let me know, if you need more information or if I misunderstood your question.