Inception_v3 is not working very well

Hey, I have implemented ResNet and DenseNet in PyTorch, and now I am working with Inception v3. When I first used it, it threw an error, which I solved by changing transforms.RandomResizedCrop(244) to transforms.RandomResizedCrop(299) in my training transform:

transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop(299),  # was 244; Inception v3 expects 299x299 inputs
                                      # transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                      ])

But after that, it is giving me this error -

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input> in <module>()
     47 
     48 output = model.forward(images)
---> 49 loss = criterion(output, labels)
     50 loss.backward()
     51 optimizer.step()

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    208     @weak_script_method
    209     def forward(self, input, target):
--> 210         return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
    211 
    212 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1780     if size_average is not None or reduce is not None:
   1781         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1782     dim = input.dim()
   1783     if dim < 2:
   1784         raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))

AttributeError: 'tuple' object has no attribute 'dim'

Here is my code -

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = models.inception_v3(pretrained=True)

for param in model.parameters():
    param.requires_grad = True
    param.aux_logits=False
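    # (note: this sets an aux_logits attribute on each Parameter object, not on the model itself)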

model.fc = nn.Sequential(nn.Linear(2048, 1024),
                         nn.ReLU(),
                         nn.Dropout(0.4),
                         nn.Linear(1024, 512),
                         nn.ReLU(),
                         nn.Dropout(0.4),
                         nn.Linear(512, 4),
                         nn.LogSoftmax(dim=1))

criterion = nn.NLLLoss()

optimizer = optim.Adam(model.fc.parameters(), lr = 0.0001)

model.to(device);
                
epochs = 60
#steps = 0
#print_every = 5

for epoch in range(epochs):
  
  running_loss = 0
  model.train()
  for images, labels in dataloader_train:
    
    #steps += 1
    images, labels = images.to(device), labels.to(device)
    
    optimizer.zero_grad()
    
    output = model.forward(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
  #if steps % print_every == 0:
  valid_loss = 0
  accuracy = 0
  model.eval()
  for images, labels in dataloader_test:
    optimizer.zero_grad()
    with torch.no_grad():
       
      images, labels = images.to(device), labels.to(device)

      output = model.forward(images)
      loss = criterion(output, labels)
          
      valid_loss += loss.item()
          
      ps = torch.exp(output)
         
      top_p, top_class = ps.topk(1, dim = 1)
      equals = top_class == labels.view(*top_class.shape)
      accuracy += torch.mean(equals.type(torch.FloatTensor))

I have read many posts, but none of them addresses my particular problem. Could you please help me resolve the issue? Thanks in advance.

In the default setup, your Inception model will output two values: the output from the last layer and the auxiliary logits.
If you don't need the latter, create your model with aux_logits=False:

model = models.inception_v3(pretrained=True, aux_logits=False)
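For reference, in training mode you can see the two outputs directly; here is a minimal sketch (the shapes assume the default 1000-class head):

model = models.inception_v3(pretrained=True)  # aux_logits=True by default
model.train()
x = torch.randn(2, 3, 299, 299)
out, aux = model(x)  # main logits and auxiliary logits
print(out.shape)     # torch.Size([2, 1000])
print(aux.shape)     # torch.Size([2, 1000])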

I did that, and this is the result -

RuntimeError: Error(s) in loading state_dict for Inception3:
Unexpected key(s) in state_dict: "AuxLogits.conv0.conv.weight", "AuxLogits.conv0.bn.weight", "AuxLogits.conv0.bn.bias", "AuxLogits.conv0.bn.running_mean", "AuxLogits.conv0.bn.running_var", "AuxLogits.conv1.conv.weight", "AuxLogits.conv1.bn.weight", "AuxLogits.conv1.bn.bias", "AuxLogits.conv1.bn.running_mean", "AuxLogits.conv1.bn.running_var", "AuxLogits.fc.weight", "AuxLogits.fc.bias".

Yeah, right, my bad. We've had this issue in the past.
In that case, just leave the instantiation of your model at aux_logits=True and pass only the first output of your model to the criterion:

outputs = model(data)
loss = criterion(outputs[0], target)
I tried that, but now I get this error -

Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-42426169cbfd> in <module>()
     44     optimizer.zero_grad()
     45 
---> 46     output = model.forward(images)
     47     loss = criterion(output[0], labels)
     48     loss.backward()

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
     94         x = self.Mixed_5c(x)
     95         # 35 x 35 x 288
---> 96         x = self.Mixed_5d(x)
     97         # 35 x 35 x 288
     98         x = self.Mixed_6a(x)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
    144 
    145     def forward(self, x):
--> 146         branch1x1 = self.branch1x1(x)
    147 
    148         branch5x5 = self.branch5x5_1(x)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
    323 
    324     def forward(self, x):
--> 325         x = self.conv(x)
    326         x = self.bn(x)
    327         return F.relu(x, inplace=True)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    318     def forward(self, input):
    319         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 320                         self.padding, self.dilation, self.groups)
    321 
    322 

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in handler(signum, frame)
    272         # This following call uses `waitid` with WNOHANG from C side. Therefore,
    273         # Python can still get and update the process status successfully.
--> 274         _error_if_any_worker_fails()
    275         if previous_handler is not None:
    276             previous_handler(signum, frame)

RuntimeError: DataLoader worker (pid 6933) is killed by signal: Killed.

Might be unrelated to the first issue.
Could you set num_workers=0 and run the code again to get a proper error message?

> # choose the training and test datasets
> train_data = datasets.ImageFolder(data+"/train", transform=transform_train)
> test_data = datasets.ImageFolder(data+"/val", transform = transform_test)
> 
> dataloader_train = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)
> dataloader_test = torch.utils.data.DataLoader(test_data, batch_size=32, num_workers=0)

But it just runs and then vanishes.

What do you mean by “vanish”?
Does your kernel just die, or do you get an error message?
If you are running your script in a Jupyter notebook, try exporting it as a .py file and running it in a terminal.
Sometimes the IPython kernel just dies without printing the error message.

I am running it in Google Colab. Every time I run it, it doesn't crash, but it also doesn't give me any results. It just dies on its own.

In case you are using the GPU, could you try to run it on the CPU in Colab?
If that’s not the case, running it locally on your machine might give you an exception.
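A quick way to force the CPU for debugging would be (a minimal sketch; it overrides the device line from your script):

device = torch.device('cpu')
model.to(device)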

Well, sorry, I was running it without the GPU; now I am enabling the GPU. Let's see.

Got this error -

ValueError                                Traceback (most recent call last)
<ipython-input-4-a22bbff7921a> in <module>()
     62 
     63       output = model.forward(images)
---> 64       loss = criterion(output[0], labels)
     65 
     66       valid_loss += loss.item()

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    208     @weak_script_method
    209     def forward(self, input, target):
--> 210         return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
    211 
    212 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1782     dim = input.dim()
   1783     if dim < 2:
-> 1784         raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))
   1785 
   1786     if input.size(0) != target.size(0):

ValueError: Expected 2 or more dimensions (got 1)


I removed aux_logits=True. Can you please tell me what the problem is now?

Could you print the following for debugging purposes:

outputs = model(data)
print(len(outputs))
print(type(outputs))
print(outputs[0].shape)
print(labels.shape)

This is what I got -
2
<class 'tuple'>
torch.Size([32, 4])
torch.Size([32])

I'm not sure what's going on, as outputs[0] has 2 dimensions.
Have a look at this small example and try to compare your code to it:

model = models.inception_v3(pretrained=True)
x = torch.randn(2, 3, 299, 299)
target = torch.randint(0, 1000, (2,))
criterion = nn.CrossEntropyLoss()

output = model(x)
loss = criterion(output[0], target)
loss.backward()
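Note that this example feeds the raw logits into nn.CrossEntropyLoss, which combines nn.LogSoftmax and nn.NLLLoss, so it is equivalent to your LogSoftmax + NLLLoss setup.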

Hey, I eventually dug into this and found that when I remove the [0] from
loss = criterion(output[0], labels) in the validation loop (while keeping it in the training loop), like this:

for epoch in range(epochs):
  
  running_loss = 0
  model.train()
  for images, labels in dataloader_train:
    
    #steps += 1
    images, labels = images.to(device), labels.to(device)
    
    optimizer.zero_grad()
    
    output = model.forward(images)
    loss = criterion(output[0], labels)
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
  #if steps % print_every == 0:
  valid_loss = 0
  accuracy = 0
  model.eval()
  for images, labels in dataloader_test:
    optimizer.zero_grad()
    with torch.no_grad():
       
      images, labels = images.to(device), labels.to(device)

      output = model.forward(images)
      loss = criterion(output, labels)
          
      valid_loss += loss.item()
          
      ps = torch.exp(output)

It suddenly worked. Can you please explain why that is? You helped me a lot. Thanks.

I'm glad you figured it out!
During evaluation, i.e. after you call model.eval(), you won't get the auxiliary output; the model returns a single tensor instead of a tuple. Indexing that tensor as output[0] selects the 1-dimensional logits of the first sample, which is why nll_loss complained about getting only 1 dimension.
So you just have to pass output[0] while training and can pass output directly during evaluation.
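In code, the pattern looks like this (a minimal sketch; in the torchvision version used here the train-mode forward returns a plain tuple):

model.train()
output = model(images)  # (logits, aux_logits) tuple in train mode
loss = criterion(output[0], labels)

model.eval()
with torch.no_grad():
    output = model(images)  # a single logits tensor in eval mode
    loss = criterion(output, labels)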
