AttributeError: 'tuple' object has no attribute 'dim' when using transfer learning with inception_v3

I’m trying to classify my images using transfer learning with inception_v3 and am getting an error:
RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
What could be the reason for this error?
My transform


# Preprocessing pipeline: center-crop each image to 1000x1000, resize it to the
# 299x299 resolution used with inception_v3, convert it to a tensor and
# normalize every RGB channel with mean 0.5 / std 0.5.
# NOTE(review): with pretrained=True, torchvision sets transform_input=True on
# Inception3, which assumes ImageNet mean/std normalization of the input —
# confirm the 0.5 values here are intended.
transform = transforms.Compose([ transforms.CenterCrop(1000), transforms.Resize((299,299)),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

My model implementation

# Use GPU if it's available
from collections import OrderedDict
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet-pretrained Inception v3.
# NOTE(review): in training mode this model returns a (logits, aux_logits)
# tuple rather than a single tensor, so the loss has to be computed on the
# first element of the output.
model = models.inception_v3(pretrained=True)

# Freeze parameters so we don't backprop through them
for param in model.parameters():
    param.requires_grad = False
    
# New 2-class head: 2048 -> 500 -> ReLU -> 2. The final LogSoftmax emits
# log-probabilities, which is the input nn.NLLLoss (below) expects.
classifier = nn.Sequential(OrderedDict([
                          ('fc1', nn.Linear(2048, 500)),
                          ('relu', nn.ReLU()),
                          ('fc2', nn.Linear(500, 2)),
                          ('output', nn.LogSoftmax(dim=1))
                          ]))
    
# NOTE(review): Inception3's final layer is `model.fc`; the model has no
# `classifier` layer in its forward pass. This head is therefore attached but
# never used. Because every original parameter is frozen above, the loss then
# has no grad_fn and backward() fails. Assign the head to `model.fc` instead.
model.classifier = classifier

# Log-probabilities from LogSoftmax pair with the negative log-likelihood loss.
criterion = nn.NLLLoss()

# Only train the classifier parameters, feature parameters are frozen
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

# The trailing ';' only suppresses the notebook's echo of the returned model.
model.to(device);

And here is the full error I am getting:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-11-912dafb06827> in <module>
     12 
     13         logps = model.forward(inputs)
---> 14         loss = criterion(logps, labels)
     15         loss.backward()
     16         optimizer.step()

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    208     @weak_script_method
    209     def forward(self, input, target):
--> 210         return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
    211 
    212 

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1780     if size_average is not None or reduce is not None:
   1781         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1782     dim = input.dim()
   1783     if dim < 2:
   1784         raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))

AttributeError: 'tuple' object has no attribute 'dim'

Could you check your input in the training loop?
Based on the error message, it looks like you are trying to pass a tuple instead of a data tensor.
This might be the case if you used just one return value for both the data and the target:

for batch in loader:
    data = batch[0]
    target = batch[1]
    ...

In this example, this error would be thrown if you passed `batch` directly to the model instead of `data`.

1 Like

Isn’t this line doing the same thing? `for inputs, labels in train_loader:` assigns the data to inputs and the target to labels.

Yes, should do the same.
Could you post your Dataset definition and training loop so that we could have a look?

Data

num_workers = 0
# how many samples per batch to load
batch_size = 20
# fraction of the data set to use as test (0.3 = 30%)
test_size = 0.3

# Same preprocessing for training and test images: center-crop, resize to
# 299x299, tensor conversion and per-channel 0.5/0.5 normalization.
transform = transforms.Compose([ transforms.CenterCrop(1000), transforms.Resize((299,299)),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# ImageFolder reads images from class sub-folders under data/ and labels each
# image by its sub-folder.
data_set = dset.ImageFolder(root="data",transform=transform)
# NOTE(review): this loader is not used by the training/test loops shown in
# this thread (they use train_loader / test_loader below).
dataloader = torch.utils.data.DataLoader(data_set, batch_size=4,shuffle=True,num_workers=2)

# obtain training indices that will be used for test
# Shuffle all indices once, then put the first `split` (test_size fraction)
# into the test set and the rest into the training set, so the two are disjoint.
num_data = len(data_set)
indices = list(range(num_data))
np.random.shuffle(indices)
split = int(np.floor(test_size * num_data))
train_idx, test_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
test_sampler  = SubsetRandomSampler(test_idx)

# prepare data loaders
# Both loaders read the same data_set; only the samplers' index sets differ.
train_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
                                           sampler = train_sampler, num_workers=num_workers)
test_loader  = torch.utils.data.DataLoader(data_set, batch_size=batch_size, 
                                           sampler = test_sampler, num_workers=num_workers)

# Presumably the sub-folder names under data/ in ImageFolder's (sorted) label
# order — confirm against data_set.classes.
classes = ('ebrus','suminagashis')

Training loop

# Train for `epochs` epochs; after each epoch, report the training loss and the
# test loss/accuracy.
epochs = 5
steps = 0
running_loss = 0
print_every = 5
for epoch in range(epochs):
    for inputs, labels in train_loader:
        steps += 1
        # Move input and label tensors to the default device
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # NOTE(review): in training mode inception_v3 returns a
        # (logits, aux_logits) tuple. Passing that tuple to criterion is what
        # raises "'tuple' object has no attribute 'dim'".
        # Calling model(inputs) instead of model.forward(inputs) also runs
        # any registered hooks.
        logps = model.forward(inputs)
        loss = criterion(logps, labels)
        # NOTE(review): with backward() commented out, no gradients are
        # computed, so optimizer.step() has nothing to apply.
#         loss.backward()
        optimizer.step()

        # Summed (not averaged) over the epoch's batches; reset after printing.
        running_loss += loss.item()
        

    test_loss = 0
    accuracy = 0
    # Evaluation mode: here the model returns a single logits tensor, not a tuple.
    model.eval()
    with torch.no_grad():
         for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logps = model.forward(inputs)
            batch_loss = criterion(logps, labels)
                    
            test_loss += batch_loss.item()
                    
                    # Calculate accuracy
            # logps are log-probabilities, so exp() recovers probabilities.
            ps = torch.exp(logps)
            top_p, top_class = ps.topk(1, dim=1)
            # top_class has shape (batch, 1); reshape labels to match.
            equals = top_class == labels.view(*top_class.shape)
            accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
                    
    # Train loss is the epoch total; test loss and accuracy are per-batch means.
    print(f"Epoch {epoch+1}/{epochs}.. "
           f"Train loss: {running_loss:.3f}.. "
           f"Test loss: {test_loss/len(test_loader):.3f}.. "
           f"Test accuracy: {accuracy/len(test_loader):.3f}")
    running_loss = 0
    # Back to training mode for the next epoch.
    model.train()

Thanks for the code. It looks alright, so could you add these print statements into the training loop:

inputs, labels = inputs.to(device), labels.to(device)
# Debug output: confirm both are tensors and check the batch shape and rank
# before they reach the model.
print(type(inputs))
print(type(labels))
print(inputs.size())
print(inputs.dim())

I'm getting this output:
[image]

Thanks for the debugging.
It seems the error is related to the output of your model.
By default, the Inception model returns two outputs in training mode: the class logits and the auxiliary logits (aux_logits).
This is explained in the original inception paper.
In case you don’t need the auxiliary loss, you can create the model using aux_logits=False:

# Build Inception v3 without the auxiliary classifier so that forward() returns
# only the class logits.
# NOTE: on the torchvision version used in this thread, combining this with
# pretrained=True failed with "Unexpected key(s) in state_dict", because the
# downloaded checkpoint contains AuxLogits weights.
model = models.inception_v3(pretrained=True, aux_logits=False)

This is what I get when I add aux_logits=False:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-ffd4a44975ff> in <module>
      3 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      4 
----> 5 model = models.inception_v3(pretrained=True,aux_logits=False)
      6 
      7 # Freeze parameters so we don't backprop through them

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torchvision/models/inception.py in inception_v3(pretrained, **kwargs)
     25             kwargs['transform_input'] = True
     26         model = Inception3(**kwargs)
---> 27         model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
     28         return model
     29 

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    767         if len(error_msgs) > 0:
    768             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769                                self.__class__.__name__, "\n\t".join(error_msgs)))
    770 
    771     def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for Inception3:
	Unexpected key(s) in state_dict: "AuxLogits.conv0.conv.weight", "AuxLogits.conv0.bn.weight", "AuxLogits.conv0.bn.bias", "AuxLogits.conv0.bn.running_mean", "AuxLogits.conv0.bn.running_var", "AuxLogits.conv1.conv.weight", "AuxLogits.conv1.bn.weight", "AuxLogits.conv1.bn.bias", "AuxLogits.conv1.bn.running_mean", "AuxLogits.conv1.bn.running_var", "AuxLogits.fc.weight", "AuxLogits.fc.bias".

My overall goal is to choose the latest pretrained models, use them on my dataset, and compare all of them. It looks like this tutorial covers what I am trying to do: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
except that I want to use the best-performing version for models that have more than one version. It looks like we can also cite this webpage in our paper. I’ll reimplement the whole transfer learning section, then come back here and update this thread. Hopefully I’ll find a solution.

In that case, could you try passing logps[0] to your loss function? The first entry should be the class logits, while the second one should be the aux_logits.

Now I’m getting `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`. Does this mean inception_v3 does not have a backprop function?

Now I’m getting this error:

ValueError                                Traceback (most recent call last)
<ipython-input-18-bac6a09fefde> in <module>
     26             inputs, labels = inputs.to(device), labels.to(device)
     27             logps = model.forward(inputs)
---> 28             batch_loss = criterion(logps[0], labels)
     29 
     30             test_loss += batch_loss.item()

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    208     @weak_script_method
    209     def forward(self, input, target):
--> 210         return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
    211 
    212 

~/anaconda3/envs/pytorch10/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1782     dim = input.dim()
   1783     if dim < 2:
-> 1784         raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))
   1785 
   1786     if input.size(0) != target.size(0):

ValueError: Expected 2 or more dimensions (got 1)

Could you print out the shape of logps[0]? It should be [batch_size, nb_classes].

I also just realized, that you are assigning your Sequential classifier module to model.classifier.
If you are using inception_v3, you should use model.fc instead.

Here is a minimal code snippet which should work:

# Minimal check that gradients reach the final layer of Inception v3.
# Modules start in training mode, so the model returns (logits, aux_logits).
model = models.inception_v3(pretrained=True, aux_logits=True)
# CrossEntropyLoss takes raw logits and applies log-softmax itself.
criterion = nn.CrossEntropyLoss()

# Random batch of 2 RGB images at 299x299, and random targets in [0, 1000) to
# match the 1000 outputs of the unmodified ImageNet head.
x = torch.randn(2, 3, 299, 299)
target = torch.empty(2, dtype=torch.long).random_(1000)

output = model(x)
# output[0] holds the class logits; output[1] holds the auxiliary logits.
loss = criterion(output[0], target)
loss.backward()

# A non-None gradient here shows that backprop reached model.fc.
print(model.fc.weight.grad)

Could you try to compare your code to this snippet?

1 Like

Hi @ptrblck,
Thanks a lot for the help! Now my model is training. Here are the changes I made after your last reply:

  • Changed model.classifier to model.fc (is this information included in the PyTorch documentation?)
# Replacement 2-class head, assigned to `model.fc` (Inception3's final layer).
# Its LogSoftmax outputs are log-probabilities, which pair with nn.NLLLoss;
# nn.CrossEntropyLoss would apply log-softmax a second time.
fc = nn.Sequential(OrderedDict([
                          ('fc1', nn.Linear(2048, 500)),
                          ('relu', nn.ReLU()),
                          ('fc2', nn.Linear(500, 2)),
                          ('output', nn.LogSoftmax(dim=1))
                          ]))
    
model.fc = fc
  • Changed my criterion from nn.NLLLoss() to nn.CrossEntropyLoss() (I don’t think this is relevant, but I’ll check the loss documentation)

  • Changed loss = criterion(logps, labels) to loss = criterion(logps[0], labels)

The first and third changes, made following your suggestions, got the network training.

Hi Aysin,
I’m glad it’s working now!

  • You would have to print the model or look at the source to see all attributes.
  • If you are using nn.LogSoftmax, you should stick to nn.NLLLoss, since nn.CrossEntropyLoss applies log-softmax internally and it would otherwise be applied twice.

From the paper (page 6):

During training, their loss gets added to the total loss of the network with a discount weight (the losses of the auxiliary classifiers were weighted by 0.3).
~ GoogLeNet training in Caffe by amiralush · Pull Request #1367 · BVLC/caffe · GitHub

Can we do something similar in PyTorch?

Sure! You would have to set aux_logits=True when initializing the Inception model. In your training loop you will get two outputs: the output of the last linear layer and the auxiliary logits. After calculating the losses, you can weight and sum them to get the final loss, e.g. `output, aux_output = model(inputs)` followed by `loss = criterion(output, target) + 0.3 * criterion(aux_output, target)`.

2 Likes

Hello,
I got the same error:
AttributeError: 'tuple' object has no attribute 'dim'
# NOTE(review): as pasted, this snippet has lost its indentation, and its
# quotes were converted to typographic characters (‘ ’ “ ”), so it is not
# valid Python. Restore straight quotes and the loop nesting before running.
print(‘Starting’)
print(‘Running for’,epochs,‘Epochs’)
#print(‘Batch Size:’, mnist.trainloader.batch_size)
#n_of_batches = int(len(mnist.train_data.data)/mnist.trainloader.batch_size)
#print(‘No of batches per epoch:’, n_of_batches)
losses = []
test_acc = []
for epoch in range(epochs):
loss2 = []
# NOTE(review): zip() pairs each item yielded by train_loader with one entry
# of train_labels. If train_loader yields (data, target) batches, `inputs`
# below is that whole pair — print type(inputs) to confirm.
for i,data in enumerate(zip(train_loader, train_labels),0):
pcae.train()
inputs, labels = data
inputs = torch.tensor(inputs, dtype = torch.float32)
labels = torch.tensor(labels, dtype = torch.float32)
inputs, labels = inputs.to(‘cuda’), labels.to(‘cuda’)
optimizer.zero_grad()
# NOTE(review): if pcae returns a tuple, loss_func(out) here and
# torch.norm(outs, dim=2) below receive a tuple, which is a likely source of
# "'tuple' object has no attribute 'dim'" — print type(out) to confirm.
out = pcae(inputs)
loss = loss_func(out)
loss.mean().backward()
optimizer.step()
loss2.append(loss.mean().item())
# print("\rEpoch #:",epoch+1, " Batch #:", i+1, “\tImage Likelihood:”,"{0:.2f}".format(-img_lik.item()),
# “\tPart Likelihood:”,"{0:.2f}".format(-part_lik.item()), “\tPrior Sparsity:”,"{0:.2f}".format(prior_sparsity.item()),
# “\tPosterior Sparsity:”,"{0:.2f}".format(posterior_sparsity.item()) end="")
print("\rEpoch #:",epoch+1, " Batch #:", i+1, “\tImage Likelihood:”,"{0:.2f}".format(loss.mean().item()), end="")
# Mean training loss over this epoch's batches.
l = np.sum(loss2)/(len(loss2))
correct = 0
total = 0
with torch.no_grad():
# NOTE(review): this evaluation pass iterates over the training data, although
# the result is printed as testing accuracy below.
for data in zip(train_loader, train_labels) :
pcae.eval()
inputs, labels = data
inputs = torch.tensor(inputs, dtype = torch.float32)
labels = torch.tensor(labels, dtype = torch.float32)
inputs, labels = inputs.to(‘cuda’), labels.to(‘cuda’)
outs= pcae(inputs)
norm = torch.norm(outs, dim=2)
predicted = np.argmax(norm.cpu().numpy(),axis=1)
total += labels.size(0)
# NOTE(review): .numpy() cannot be called on a CUDA tensor; this needs
# labels.cpu().numpy().
correct += (predicted == labels.cuda().numpy()).sum().item()
acc = (100 * correct / total)
print(’\rEpoch #:’, epoch+1, ’ Training Loss: ',"{0:.2f}".format(l))
print(‘CapsNet Testing Accuracy on MNIST: ‘,acc, ‘%’)
test_acc.append(acc)
# NOTE(review): `l` is appended to `losses` here and again a few lines below.
# With the indentation lost, it is unclear whether the first append belongs to
# the `if` — check for double counting.
if max(test_acc)==acc:
#torch.save(capsnet.state_dict(), ‘gdrive/My Drive/capsnet_40_epochs.pth’)
losses.append(l)
print(’\rEpoch #:’, epoch+1, ’ Training Loss: ',"{0:.2f}".format(l))
losses.append(l)
#torch.save(pcae.state_dict(), “./gdrive/My Drive/trained_pcae_2.pth”)
print(‘Finished’)

Are the previously mentioned solutions not working for you?
PS: you can post code snippets by wrapping them in three backticks (```), which makes debugging easier.