PyTorch LSTM image classification dimension issue

I am trying to feed my own image dataset into a PyTorch LSTM implementation that works fine with the MNIST dataset, but I am getting this error message: "RuntimeError: Assertion `THIndexTensor_(size)(target, 0) == batch_size' failed. at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/THNN/generic/ClassNLLCriterion.c:79"

import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

TRANSFORM_IMG = transforms.Compose([
    transforms.ToTensor(),  # Normalize operates on tensors, not PIL images
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

TRAIN_DATA_PATH = "./train_set/"
TEST_DATA_PATH = "./test_set/"
BATCH_SIZE = 100  # matches the labels.shape reported later in this thread
train_data = torchvision.datasets.ImageFolder(root=TRAIN_DATA_PATH, transform=TRANSFORM_IMG)
train_data_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

test_data = torchvision.datasets.ImageFolder(root=TEST_DATA_PATH, transform=TRANSFORM_IMG)
test_data_loader = data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

labels = ['low', 'medium', 'high']

# Hyper Parameters

sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 3

num_epochs = 2
learning_rate = 0.01

# RNN Model (Many-to-One)

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial states
        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        # Forward propagate RNN
        out, _ = self.lstm(x, (h0, c0))
        # Decode hidden state of last time step
        out = self.fc(out[:, -1, :])
        return out

rnn = RNN(input_size, hidden_size, num_layers, num_classes)

# Loss and Optimizer

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)

# Train the Model

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_data_loader):
        images = Variable(images.view(-1, sequence_length, input_size))
        labels = Variable(labels)

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = rnn(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            # loss.data[0] is the pre-0.4 idiom; use loss.item() on PyTorch 0.4 and later
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                  % (epoch+1, num_epochs, i+1, len(train_data)//BATCH_SIZE, loss.data[0]))

So what are the shapes of output and labels?

Best regards


outputs.shape returns torch.Size([300,3])
labels.shape returns torch.Size([100])

Thanks for the reply.

Do you have RGB images perchance and expect only a single channel?
I’m asking because images.view apparently triples shape[0]: outputs has 300 rows while labels has only 100 entries. You might want to try printing images.shape before and after the view.
A quick fix might be averaging over the channel dimension to eliminate the excess data.
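
To make the mismatch concrete, here is a minimal sketch that uses a random tensor in place of a real batch. It assumes a batch of 100 RGB images of size 28x28, which is inferred from the shapes reported above; the variable names are illustrative only.

```python
import torch

# Stand-in for one batch from the DataLoader: 100 RGB images, 28x28 pixels.
images = torch.randn(100, 3, 28, 28)

# The view from the training loop folds the channel dimension into the batch
# dimension, so 100 images become 300 "sequences".
flattened = images.view(-1, 28, 28)
print(flattened.shape)  # torch.Size([300, 28, 28])

# Averaging over the channel dimension (dim=1) keeps one sequence per image.
gray = images.mean(dim=1)
print(gray.shape)  # torch.Size([100, 28, 28])
```

Note that this is a plain mean over the three channels, not a luminance-weighted grayscale conversion.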

Best regards


P.S.: When you post code, the triple backticks (```) can be used at the beginning and end of the code to preserve the code formatting.

Oh, you are right, my images are RGB!
Is there a way to convert the whole dataset to grayscale, or how can I average over the channel dimension?

Thanks a lot for helping.

You could use the Grayscale transformation from torchvision.

Thank you, I have fixed that error. But now I am getting the following error:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/THNN/generic/ClassNLLCriterion.c:87

Could you check the min and max of your targets?
Based on the output shape, it seems you have 3 classes.
Your targets should be [0, 1, 2].
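
One way to run that check is sketched below. The targets tensor holds hypothetical values that start at 1; in the real training loop it would be the labels tensor of a batch.

```python
import torch

# Hypothetical batch of targets whose class indices start at 1 instead of 0.
targets = torch.tensor([1, 2, 3, 3, 1, 2])
num_classes = 3

print(targets.min().item(), targets.max().item())  # 1 3

# CrossEntropyLoss expects class indices in [0, num_classes - 1].
in_range = bool(targets.min() >= 0) and bool(targets.max() < num_classes)
print(in_range)  # False
```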

It is [1, 2, 3] now. How can I set it to [0, 1, 2]?

You could just subtract 1 from the target: target = target - 1.
How did you create them? You could try to create them in the right range.
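
As a rough illustration of the subtraction, the sketch below shifts hypothetical 1-based targets into the 0-based range and passes them to the loss. The outputs tensor is random and only stands in for real model outputs.

```python
import torch
import torch.nn as nn

# Hypothetical model outputs for 4 samples and 3 classes.
outputs = torch.randn(4, 3)
# Targets in the 1-based range [1, 3].
targets = torch.tensor([1, 2, 3, 1])

# Shift to the 0-based range [0, 2] that CrossEntropyLoss expects.
targets = targets - 1

criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, targets)  # no assertion failure with 0-based targets
print(targets.tolist())  # [0, 1, 2, 0]
```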