Problem understanding input dimensions for NLLLoss()

I’m building a BiLSTM for multiclass classification of gait acceleration signals. Here’s my current code:

# -*- coding: utf-8 -*-

import numpy as np
import torch
from torch import nn
from torch import optim
import time
import matplotlib.pyplot as plt

class LSTM_model(nn.Module):
    
    def __init__(self, seq_length, input_dim, num_labels, hidden_dim=256, n_layers=2, drop_prob=0.5, bidirectional = True):      
        
        super().__init__()
        
        # Storing arguments as class attributes
        self.input_dim = input_dim
        self.num_labels = num_labels
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.drop_prob = drop_prob
        self.bidirectional = bidirectional
             
        # LSTM layer
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.n_layers, dropout=self.drop_prob, 
                            batch_first=True, bidirectional = self.bidirectional)
        
        # Dropout layer
        self.dropout = nn.Dropout(p=self.drop_prob)
        
        # Fully connected layer (size 2 * hidden_dim because it's bidirectional)
        self.fc = nn.Linear(self.hidden_dim * 2, self.num_labels)
        
        # LogSoftmax over the class dimension (last axis) for classification
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        
    def forward(self, x, h=None):

        if h is None:
            # If no initial hidden_state and memory are provided, they are set to 0
            r_output, hidden = self.lstm(x)  
        
        else:
            r_output, hidden = self.lstm(x,h)     
        
        # Pass through a dropout layer
        out = self.dropout(r_output)
        
        # Put x through the fully-connected layer and a logsoftmax output to determine classes
        out = self.fc(out)
        out = self.logsoftmax(out)
        
        # Return the final output and the hidden state
        return out, hidden
    
    
class LSTM(LSTM_model):
    
    def __init__(self, batch_size, seq_length, epochs, input_dim, num_labels,  lr = 0.001, clip = 5, 
                 hidden_dim=256, n_layers=2, drop_prob=0.5, bidirectional = True):

        super().__init__(seq_length, input_dim, num_labels, hidden_dim, n_layers, drop_prob, bidirectional)
        
        # Storing arguments as class attributes
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.epochs = epochs
        self.lr = lr
        self.clip = clip
        self.num_labels = num_labels
        self.n_layers = n_layers
        
        # Defining the remaining attributes
        self.optim = optim.Adam(self.parameters(), self.lr)
        self.criterion = nn.NLLLoss()
        
        # Lists to store losses
        self.loss_during_training = []
        self.val_loss_during_training = []
        
        # Setting GPU training if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.to(self.device)
        
        # Training mode by default
        self.train()
        
        
    def get_batches(self, instance):
        """
        Generates batches of single time samples for the network

        Parameters
        ----------
        instance : numpy array

        Returns
        -------
        batches: list of numpy array

        """
         
        # Trim signal length to adjust to batch size
        total_batch_size = self.batch_size * self.seq_length
        total_batches = instance.shape[0] // total_batch_size
        instance = instance[:total_batches * total_batch_size, :]
        
        # Splitting trimmed array into batches
        batches = np.split(instance, total_batches)
        
        # Reshape each batch and split the feature axis into the 4 input
        # signals (x) and the 4 label columns (y)
        for batch in batches:

            batch = batch.reshape(self.seq_length, self.batch_size, -1)

            x = batch[:, :, 0:4]
            y = batch[:, :, 4:8]
            
            yield x, y
            
            
    def trainloop(self, data):
        """
        Network training method

        Parameters
        ----------
        data : list of numpy array
            Data to be fed to the network.

        Returns
        -------
        None.

        """
        
        for e in range(self.epochs):
            
            # Batch counter
            counter = 0.
            
            # Storing current time
            start_time = time.time()
            
            running_loss = 0.
            
            # For each instance of data
            for instance in data:
                
                # Each instance has multiple batches
                for x, y in self.get_batches(instance):
                    
                    # Convert data to tensor
                    x, y = torch.from_numpy(x).float().to(self.device), torch.from_numpy(y).float().to(self.device)
                    
                    # Resetting gradients
                    self.optim.zero_grad()
                    
                    # Compute output
                    out, _ = self.forward(x)
                    
                    # Compute loss
                    loss = self.criterion(out, y)
                    
                    # Compute gradients
                    loss.backward()
                    
                    # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                    nn.utils.clip_grad_norm_(self.parameters(), self.clip)
                    
                    self.optim.step() 
                    
                    # Storing current loss and counting processed batches
                    running_loss += loss.item()
                    counter += 1
                    
                self.loss_during_training.append(running_loss/counter)
            
            # Print the average training loss once per epoch
            if e % 1 == 0:

                print("Epoch %d. Training loss: %f, Time per epoch: %f seconds"
                      %(e, self.loss_during_training[-1], (time.time() - start_time)))

The input of the net is a tensor of size (sequence length, batch size, input dimension). Sequence length is currently 250 and batch size is 20. Input dimension is always 4, since I have 4 different signals.

The output of the net is a tensor of size (sequence length, batch size, number of classes), containing the log probabilities of all classes for each time instant. I have 4 classes.

The output size coincides with my target tensor size. However, when calling the NLLLoss criterion with (out, y), the following error occurs:

ValueError: Expected target size (250, 4), got torch.Size([250, 20, 4])
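
For reference, here is a standalone snippet that reproduces the error (random tensors with the same shapes as above, not my real data):

import torch
from torch import nn

criterion = nn.NLLLoss()
out = torch.randn(250, 20, 4)   # (sequence length, batch size, number of classes)
y = torch.zeros(250, 20, 4)     # target built with the same shape as the output
loss = criterion(out, y)        # raises the "Expected target size (250, 4)" error above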

Any help will be greatly appreciated.


Hello Patxi!

This is your error. NLLLoss expects a target (your y) with one less dimension than the input (your out); that is, the target should not have the nClass dimension.

In the so-called “K-dimensional case” with K = 1, NLLLoss expects
an input of shape [nBatch, nClass, d_1] and a target of shape
[nBatch, d_1]. target does not have an nClass dimension; instead,
the values of target are integer class labels that run from 0 to
nClass - 1.

I’m guessing that you want to pass in an input (your out) of shape
[nBatch = 250, nClass = 4, d_1 = nSeq = 20], and therefore
have a target with a shape of [250, 20] whose values are in the
range [0, 3] (inclusive).
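
Here is a toy illustration of those shapes (random numbers standing in for your data, with nBatch = 250, nClass = 4, and d_1 = 20):

import torch

loss_fn = torch.nn.NLLLoss()
log_probs = torch.randn(250, 4, 20)       # input: [nBatch, nClass, d_1] (stand-in for log-probabilities)
labels = torch.randint(0, 4, (250, 20))   # target: [nBatch, d_1], integer class labels in [0, 3]
loss = loss_fn(log_probs, labels)         # no shape error -- the target has no nClass dimension

(If your out currently comes out with shape [250, 20, 4], something like out.permute(0, 2, 1) would move the class dimension into the middle.)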

Best.

K. Frank

Hi KFrank,

Thank you very much for your quick response! I was so focused on the output and the labels having the same dimensions that I couldn’t see beyond that. Your answer was on point.

However, as I came to understand better how NLLLoss works, I realized that my problem fits a multilabel classification setup better than a multiclass one, since each time instant can have multiple classes, and NLLLoss is strictly for multiclass problems where only one class can be assigned.

After searching a bit, I found that BCEWithLogitsLoss can be a good option for adapting my network to a multilabel problem.
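
This is roughly the shape setup I have in mind (just a sketch with random tensors standing in for my data and labels):

import torch

criterion = torch.nn.BCEWithLogitsLoss()
logits = torch.randn(250, 20, 4)                   # raw fc outputs (no LogSoftmax at the end)
targets = (torch.rand(250, 20, 4) > 0.5).float()   # multi-hot labels: several classes per time instant allowed
loss = criterion(logits, targets)                  # input and target have the same shape

That also means dropping the final LogSoftmax, since BCEWithLogitsLoss applies the sigmoid internally and works on the raw logits.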

I hope this gives you, or anyone who visits this post, a bit more insight. I’ll update the post if it works for me. Thank you for your help!

Best,
Patxi.
