IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) when generating predictions

I am trying to convert a Keras CNN model I created with the functional API to PyTorch. As a sanity check, I made sure the output of my model summary (torchsummary) matched the Keras model summary (the parameter counts differ only by the conv biases, since I set bias=False in PyTorch). However, when feeding my PyTorch model images in batches for training, the error “IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)” is thrown. The Keras model as well as my (flawed) PyTorch implementation look like this:

Keras:

from tensorflow.keras import Input, layers
from tensorflow.keras.models import Model

image_size = 200

input_tensor = Input(shape = (image_size, image_size, 3))
x = layers.Conv2D(32, kernel_size = (3,3), strides = (1,1),
                  activation = 'relu', padding = 'same')(input_tensor)
y = layers.MaxPool2D(pool_size = (2,2), strides = (2,2))(x)
x = layers.Conv2D(32, kernel_size = (3,3), strides = (1,1),
                  activation = 'relu', padding = 'same')(y)
x = layers.Conv2D(32, kernel_size = (3,3), strides = (1,1),
                  activation = 'relu', padding = 'same')(x)
x = layers.MaxPool2D(pool_size = (2,2), strides = (2,2))(x)
residual = layers.Conv2D(32, kernel_size = (1,1), strides = (2,2),
                         activation = 'relu', padding = 'same')(y)
y = layers.add([residual, x])
x = layers.Conv2D(64, kernel_size = (3,3), strides = (1,1),
                  activation = 'relu', padding = 'same')(y)
x = layers.Conv2D(64, kernel_size = (3,3), strides = (1,1),
                  activation = 'relu', padding = 'same')(x)
x = layers.MaxPool2D(pool_size = (2,2), strides = (2,2))(x)
residual = layers.Conv2D(64, kernel_size = (1,1), strides = (2,2),
                         padding = 'same')(y)
y = layers.add([residual, x])

x = layers.Flatten()(y)
x = layers.Dense(128, activation = 'relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(64, activation = 'relu')(x)
x = layers.Dropout(0.5)(x)

output_tensor = layers.Dense(4, activation = 'softmax')(x)  

model = Model(input_tensor, output_tensor)

print(model.summary())

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_5 (InputLayer)            [(None, 200, 200, 3) 0                                            
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 200, 200, 32) 896         input_5[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 100, 100, 32) 0           conv2d_28[0][0]                  
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 100, 100, 32) 9248        max_pooling2d_12[0][0]           
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 100, 100, 32) 9248        conv2d_29[0][0]                  
__________________________________________________________________________________________________
conv2d_31 (Conv2D)              (None, 50, 50, 32)   1056        max_pooling2d_12[0][0]           
__________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, 50, 50, 32)   0           conv2d_30[0][0]                  
__________________________________________________________________________________________________
add_8 (Add)                     (None, 50, 50, 32)   0           conv2d_31[0][0]                  
                                                                 max_pooling2d_13[0][0]           
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 50, 50, 64)   18496       add_8[0][0]                      
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 50, 50, 64)   36928       conv2d_32[0][0]                  
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 25, 25, 64)   2112        add_8[0][0]                      
__________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, 25, 25, 64)   0           conv2d_33[0][0]                  
__________________________________________________________________________________________________
add_9 (Add)                     (None, 25, 25, 64)   0           conv2d_34[0][0]                  
                                                                 max_pooling2d_14[0][0]           
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 40000)        0           add_9[0][0]                      
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 128)          5120128     flatten_4[0][0]                  
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 128)          0           dense_12[0][0]                   
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 64)           8256        dropout_4[0][0]                  
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 64)           0           dense_13[0][0]                   
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 4)            260         dropout_5[0][0]                  
==================================================================================================
Total params: 5,206,628
Trainable params: 5,206,628
Non-trainable params: 0

PyTorch:

import torch
import torch.nn as nn

from torchsummary import summary

class ResNet(nn.Module):

    def __init__(self, num_classes=4):
        super().__init__()
        
        self.conv32_1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding='same',
                               bias=False)
        
        self.conv32_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same',
                               bias=False)
        
        self.relu = nn.ReLU(inplace=True)
        
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv64_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding='same',
                               bias=False)
        
        self.conv64_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding='same',
                               bias=False)
        
        self.conv32resid = nn.Conv2d(32, 32, kernel_size=1, stride=2, #padding='same',
                               bias=False)
        
        self.conv64resid = nn.Conv2d(32, 64, kernel_size=1, stride=2, #padding='same',
                               bias=False)
        
        self.avg = nn.AdaptiveAvgPool2d((1, 1))  # defined here but not used in forward()

        self.fc1 = nn.Linear(40000 , 128)
        self.fc2 = nn.Linear(128, 64)
        self.out = nn.Linear(64, num_classes)
    
    
    def forward(self, x):
        x = self.conv32_1(x)
        x = self.relu(x)
        y = self.maxpool(x)
        x = self.conv32_2(y)
        x = self.relu(x)
        x = self.conv32_2(x)  # note: this reuses the same layer (shared weights); the Keras model has two separate convs here
        x = self.relu(x)
        x = self.maxpool(x)
        residual = self.conv32resid(y)
        residual = self.relu(residual)
        y = torch.cat((residual, x))
        
        x = self.conv64_1(y)
        x = self.relu(x)
        x = self.conv64_2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        residual = self.conv64resid(y)
        residual = self.relu(residual)
        y = torch.cat((residual, x))
           
        x = torch.flatten(y)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.out(x)

        return x

model = ResNet()

summary(model.cuda(), (3, 200, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 200, 200]             864
              ReLU-2         [-1, 32, 200, 200]               0
         MaxPool2d-3         [-1, 32, 100, 100]               0
            Conv2d-4         [-1, 32, 100, 100]           9,216
              ReLU-5         [-1, 32, 100, 100]               0
            Conv2d-6         [-1, 32, 100, 100]           9,216
              ReLU-7         [-1, 32, 100, 100]               0
         MaxPool2d-8           [-1, 32, 50, 50]               0
            Conv2d-9           [-1, 32, 50, 50]           1,024
             ReLU-10           [-1, 32, 50, 50]               0
           Conv2d-11           [-1, 64, 50, 50]          18,432
             ReLU-12           [-1, 64, 50, 50]               0
           Conv2d-13           [-1, 64, 50, 50]          36,864
             ReLU-14           [-1, 64, 50, 50]               0
        MaxPool2d-15           [-1, 64, 25, 25]               0
           Conv2d-16           [-1, 64, 25, 25]           2,048
             ReLU-17           [-1, 64, 25, 25]               0
           Linear-18                       [-1]       5,120,128
             ReLU-19                       [-1]               0
           Linear-20                       [-1]           8,256
             ReLU-21                       [-1]               0
           Linear-22                       [-1]             260
================================================================
Total params: 5,206,308
Trainable params: 5,206,308
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.46
Forward/backward pass size (MB): 39.37
Params size (MB): 19.86
Estimated Total Size (MB): 59.69
----------------------------------------------------------------

The code for training is the following:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score  # used below to compute the accuracies
import numpy as np

image_size = 200
batch_size = 8

# train_dir / val_dir / test_dir and the transforms are defined elsewhere (not shown)
train_dataset = datasets.ImageFolder(train_dir, transform=transforms_train)

val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

test_dataset = datasets.ImageFolder(test_dir, transform=transform_val)

train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True, 
        drop_last=True, 
        pin_memory=True)

val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size)

test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

epochs = 20
min_valid_loss = np.inf
loss_values_train = []
loss_values_val = []

acc_values_train = []
acc_values_val = []
for e in range(epochs):
    y_pred_train = []
    y_true_train = []

    train_loss = 0.0
    model.train()     
    for data, labels in train_loader:
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        
        optimizer.zero_grad()
        target = model(data)
        loss = criterion(target,labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0) 
        
        _, preds = torch.max(target, 1)
        y_pred_train.append(preds.tolist())
        y_true_train.append(labels.tolist())

    y_pred_train = [item for sublist in y_pred_train for item in sublist]
    y_true_train = [item for sublist in y_true_train for item in sublist]
    train_acc = accuracy_score(y_true_train, y_pred_train)
    acc_values_train.append(train_acc)
    loss_values_train.append(train_loss / len(train_loader.dataset))  # per-sample average (loss was accumulated per sample)
    
    y_pred_val = []
    y_true_val = []
    
    valid_loss = 0.0
    model.eval()
    with torch.no_grad():  # no gradients are needed for validation
        for data, labels in val_loader:
            if torch.cuda.is_available():
                data, labels = data.cuda(), labels.cuda()

            target = model(data)
            loss = criterion(target, labels)
            valid_loss += loss.item() * data.size(0)

            _, preds = torch.max(target, 1)
            y_pred_val.append(preds.tolist())
            y_true_val.append(labels.tolist())
        
    y_pred_val = [item for sublist in y_pred_val for item in sublist]
    y_true_val = [item for sublist in y_true_val for item in sublist]
    val_acc = accuracy_score(y_true_val, y_pred_val)
    acc_values_val.append(val_acc)
    loss_values_val.append(valid_loss / len(val_loader.dataset))
    
    print(f'Epoch {e+1}: \n Training Loss: {train_loss/len(train_loader.dataset)} \t Training Acc: {train_acc} \n Validation Loss: {valid_loss/len(val_loader.dataset)} \t Validation Acc: {val_acc}')
    if min_valid_loss > valid_loss:
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), 'resnet_model.pth')

My guess is that the flatten operation at the end of my model is somehow causing problems. Any help would be greatly appreciated!

x = torch.flatten(y) flattens your y tensor completely, removing the first dimension, which is the batch size. Use x = torch.flatten(y, start_dim=1) to keep the batch dimension (see the docs for torch.flatten). The -1 you see in your summary is supposed to represent the batch size (the same thing as None in your Keras model summary). Instead of [-1, 40000], which is equivalent to (None, 40000), you got only [-1].
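
A minimal sketch of the difference, using the (64, 25, 25) activation that reaches the flatten in your model and a batch of 8:

import torch

y = torch.randn(8, 64, 25, 25)              # (batch, channels, H, W), as after the second add
print(torch.flatten(y).shape)               # torch.Size([320000]) -> batch dimension gone
print(torch.flatten(y, start_dim=1).shape)  # torch.Size([8, 40000]) -> batch dimension kept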

Thanks a lot, @Youyoun, your solution resolved the IndexError.
However, I now get a different error when trying to train the model:

The training loop is unchanged from the one above; the only difference is a print(target) added right after the forward pass to inspect the raw model output:

        optimizer.zero_grad()
        target = model(data)
        print(target)
        loss = criterion(target, labels)
        loss.backward()
        optimizer.step()

---------------------------------------------------------------------------

tensor([[ 0.1182,  0.1314,  0.0728, -0.0236],
        [ 0.1169,  0.1362,  0.0707, -0.0107],
        [ 0.1232,  0.1335,  0.0721, -0.0102],
        [ 0.1210,  0.1218,  0.0656, -0.0114],
        [ 0.1180,  0.1376,  0.0793, -0.0139],
        [ 0.1136,  0.1300,  0.0782, -0.0301],
        [ 0.1139,  0.1294,  0.0726, -0.0256],
        [ 0.1209,  0.1298,  0.0703, -0.0201],
        [ 0.1202,  0.1272,  0.0728, -0.0251],
        [ 0.1221,  0.1278,  0.0727, -0.0245],
        [ 0.1231,  0.1286,  0.0757, -0.0226],
        [ 0.1216,  0.1305,  0.0756, -0.0254],
        [ 0.1219,  0.1298,  0.0752, -0.0226],
        [ 0.1213,  0.1308,  0.0761, -0.0256],
        [ 0.1234,  0.1294,  0.0742, -0.0261],
        [ 0.1210,  0.1275,  0.0736, -0.0225],
        [ 0.1211,  0.1280,  0.0748, -0.0246],
        [ 0.1237,  0.1290,  0.0737, -0.0197],
        [ 0.1225,  0.1302,  0.0763, -0.0223],
        [ 0.1222,  0.1330,  0.0765, -0.0226],
        [ 0.1231,  0.1316,  0.0746, -0.0217],
        [ 0.1209,  0.1325,  0.0772, -0.0254],
        [ 0.1235,  0.1270,  0.0750, -0.0191],
        [ 0.1219,  0.1286,  0.0753, -0.0236],
        [ 0.1199,  0.1309,  0.0744, -0.0239],
        [ 0.1213,  0.1314,  0.0740, -0.0223],
        [ 0.1215,  0.1310,  0.0742, -0.0223],
        [ 0.1214,  0.1317,  0.0738, -0.0223],
        [ 0.1224,  0.1317,  0.0734, -0.0223],
        [ 0.1211,  0.1314,  0.0740, -0.0222],
        [ 0.1206,  0.1305,  0.0748, -0.0225],
        [ 0.1213,  0.1313,  0.0745, -0.0228]], device='cuda:0',
       grad_fn=<AddmmBackward>)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-14-f023976bdd04> in <module>
     19         target = model(data)
     20         print(target)
---> 21         loss = criterion(target,labels)
     22         loss.backward()
     23         optimizer.step()

~\miniconda3\envs\glaucoma\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~\miniconda3\envs\glaucoma\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
   1119     def forward(self, input: Tensor, target: Tensor) -> Tensor:
   1120         return F.cross_entropy(input, target, weight=self.weight,
-> 1121                                ignore_index=self.ignore_index, reduction=self.reduction)
   1122 
   1123 

~\miniconda3\envs\glaucoma\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825 
   2826 

ValueError: Expected input batch_size (32) to match target batch_size (8).

When printing out the shape of the tensor after the flatten operation, I get the following:

x = torch.flatten(y, start_dim=1)
print(x.shape)

---------------------------------------------------------------------------

torch.Size([8, 40000])

It seems to me as if there is still an issue with the dimensionality of my tensor.
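
A quick check on the model output itself (the 32×4 tensor printed above) confirms this: the batch dimension is growing somewhere inside the forward pass, not in the flatten:

target = model(data)   # data has shape [8, 3, 200, 200]
print(target.shape)    # torch.Size([32, 4]) -> four times the batch size of 8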

Nevermind, I solved it.
The issue was with how I combined the two tensors for the skip connections: torch.cat concatenates along dim 0 by default, i.e. along the batch dimension, whereas Keras' layers.add adds element-wise.
Instead of

y = torch.cat((residual, x))

simply changing it to

y = residual + x

resolved the error.
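
For reference, a minimal sketch with the shapes from the first skip connection (residual from conv32resid, x from the second max-pool):

import torch

residual = torch.randn(8, 32, 50, 50)
x = torch.randn(8, 32, 50, 50)
print(torch.cat((residual, x)).shape)  # torch.Size([16, 32, 50, 50]) -> batches stacked along dim 0
print((residual + x).shape)            # torch.Size([8, 32, 50, 50])  -> element-wise sum, like layers.add

Two such concatenations per forward pass are exactly what turned the batch of 8 into the 32 rows seen earlier.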

The model is now running, thanks again!