Expected input batch_size (480) to match target batch_size (32)

Hello, I’m beginner to pytorch, trying to solve a text multi classification problem with Pytorch. Here is the code for my model which has 2-layer LSTM.

class netRNN(nn.Module):
    def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):
        super(SentimentRNN, self).__init__()

        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        
        # embedding and LSTM layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, 
                            dropout=drop_prob, batch_first=True)
        
        # dropout layer
        self.dropout = nn.Dropout(0.2)
        
        # linear layer
        self.fc = nn.Linear(hidden_dim, output_size)


    def forward(self, x, hidden):
        batch_size = x.size(0)

        # embeddings and lstm_out
        x = x.long()
        embeds = self.embedding(x)
        lstm_out, hidden = self.lstm(embeds, hidden)
    
        # stack up lstm outputs
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)
        
        # dropout and fully-connected layer
        out = self.dropout(lstm_out)
        out = self.fc(out)
        
        return out, hidden
    
    
    def init_hidden(self, batch_size):
        # Create two new tensors with sizes n_layers x batch_size x hidden_dim,
        # initialized to zero, for hidden state and cell state of LSTM
        weight = next(self.parameters()).data
        
        if (train_on_gpu):
            hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda(),
                  weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda())
        else:
            hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_(),
                      weight.new(self.n_layers, batch_size, self.hidden_dim).zero_())
        
        return hidden

These are my parameters which I used in the model:

vocab_size = len(vocab_to_int)+1 #I have around 5085 words
output_size = 4 #number of labels
embedding_dim = 100
hidden_dim = 50
n_layers = 2

lr=0.001

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

Here is the training code which throws the error:

epochs = 30
counter = 0
print_every = 100
clip=5 # gradient clipping

# move model to GPU, if available
if(train_on_gpu):
    net.cuda()

net.train()
# train for some number of epochs
for e in range(epochs):
    # initialize hidden state
    h = net.init_hidden(batch_size)

    # batch loop
    for inputs, labels in train_loader:
        inputs = inputs.long()
        labels = labels.long()

        counter += 1

        if(train_on_gpu):
            inputs, labels = inputs.cuda(), labels.cuda()

        # Creating new variables for the hidden state, otherwise
        # we'd backprop through the entire training history
        h = tuple([each.data for each in h])

        # zero accumulated gradients
        net.zero_grad()

        # get the output from the model
        output, h = net(inputs, h)
        print(output.shape)

        # calculate the loss and perform backprop
        loss = criterion(output.squeeze(), labels)
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        # loss stats
        if counter % print_every == 0:
            # Get validation loss
            val_h = net.init_hidden(batch_size)
            val_losses = []
            net.eval()
            for inputs, labels in valid_loader:

                inputs = inputs.long()
                labels = labels.long()
                # Creating new variables for the hidden state, otherwise
                # we'd backprop through the entire training history
                val_h = tuple([each.data for each in val_h])

                if(train_on_gpu):
                    inputs, labels = inputs.cuda(), labels.cuda()
                
                #if( (inputs.shape[0],inputs.shape[1]) != (batch_size,seq_length)):
                 # print("Validation - Input Shape Issue:",inputs.shape)
                  #continue

                output, val_h = net(inputs, val_h)
                val_loss = criterion(output.squeeze(), labels)

                val_losses.append(val_loss.item())

            net.train()
            print("Epoch: {}/{}...".format(e+1, epochs),
                  "Step: {}...".format(counter),
                  "Loss: {:.6f}...".format(loss.item()),
                  "Val Loss: {:.6f}".format(np.mean(val_losses)))

Here is the complete error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-130-585e071419a0> in <module>()
     48 
     49         # calculate the loss and perform backprop
---> 50         loss = criterion(output.squeeze(), labels)
     51         loss.backward()
     52         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.

3 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2111     if input.size(0) != target.size(0):
   2112         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2113                          .format(input.size(0), target.size(0)))
   2114     if dim == 2:
   2115         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (480) to match target batch_size (32).

Could someone please help me figure out the issue?

Can you tell what is the size of output and labels before it goes inside the criterion.?

Yes sure, these are the sizes:
input size = torch.Size([32, 15])
output size = torch.Size([480, 4])
labels size = torch.Size([32])

If labels is of size [32], then output must be of size [32,num_classes] inorder to agree with nn.CrossEntropyLoss()

Yes I’m aware of that, but I’m not sure how I can fix my output size so it becomes 32x4 instead of 480x4

Can you print what is the size before and after
lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim) in the forward() function ?

Sure, here are the sizes:

before: torch.Size([32, 15, 50])
after: torch.Size([480, 50])

See, the issue is here.
Change hidden_dim to 750 instead of 50 so that self.hidden will be 750 in
lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim).
After .view() , the size changes to [480,50], instead it must change to [32,750].
This way, the batch_size of 32 is maintained.

I ended up adding these two statements after my linear layer to get final size of output to be 32x4

out = out.view(batch_size, -1, self.output_size)
out = out[:, -1]

Thanks for your help and clarification.

1 Like