Multi-Class Classification with nn.CrossEntropyLoss

My loss is decreasing, but so is my accuracy: with CrossEntropyLoss the accuracy sits at 12-15%. With the same network, except with a softmax on the last layer and MSELoss as the loss function, I get 96+% accuracy. I really want to know what I am doing wrong with CrossEntropyLoss. Here is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

class Conv1DModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv1d(8, 16, kernel_size=8)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=8)
        
        self.bn1 = nn.BatchNorm1d(32)
        
        self.conv3 = nn.Conv1d(32, 64, kernel_size=8)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=8)
        
        self.bn2 = nn.BatchNorm1d(128)
        
        self.flat = nn.Flatten()
        
        self.fc1 = nn.Linear(128*14, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, classes)  # classes (number of target classes) is defined elsewhere
        
    def exec_conv_block(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool1d(F.relu(self.conv2(x)), 2)
        x = self.bn1(x)
        x = F.relu(self.conv3(x))
        x = F.max_pool1d(F.relu(self.conv4(x)), 2)
        x = self.bn2(x)
        
        return x
    
    def forward(self, x):
        x = self.exec_conv_block(x)
        x = self.flat(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Conv1DModel().to(device)
opt = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

def fwd_pass(X, y, train=False):
    # Enable gradients only when training; evaluation runs under no_grad.
    with torch.set_grad_enabled(train):
        outputs = model(X)
        matches = [torch.argmax(i) == torch.argmax(j) for i, j in zip(outputs, y)]

        acc = matches.count(True)/len(matches)

        loss = loss_fn(outputs, y)

    if train:
        loss.backward()
        opt.step()
        model.zero_grad()

    return acc, loss

def train(net, epochs, batch_size, X, y, val_X=None, val_y=None):
    accuracies = []
    losses = []
    val_accuracies = []
    val_losses = []
    
    for ep in tqdm(range(epochs)):
        for i in tqdm(range(0, len(X), batch_size)):
            batch_X = X[i:i+batch_size].to(device)
            batch_y = y[i:i+batch_size].to(device)
            
            acc, loss = fwd_pass(batch_X, batch_y, train=True)
            
            torch.cuda.empty_cache()
            
        if val_X is not None and val_y is not None:
            val_acc, val_loss = fwd_pass(val_X.to(device), val_y.to(device))
        
        print(f'Epoch: {ep+1}\nAcc: {round(float(acc), 3)} Loss: {round(float(loss), 4)}')
        
        if val_X is not None and val_y is not None:
            print(f'Val Acc: {round(float(val_acc), 3)} Val Loss: {round(float(val_loss), 4)}')
        
        accuracies.append(acc)
        losses.append(float(loss))  # store a plain float, not the graph-holding tensor
        
        if val_X is not None and val_y is not None:
            val_accuracies.append(val_acc)
            val_losses.append(float(val_loss))
            
    return accuracies, losses, val_accuracies, val_losses

EPOCHS = 10

acc, loss, val_acc, val_loss = train(model, EPOCHS, batch_size=128,
                                     X=X_train.transpose(1, 2), y=y_train,
                                     val_X=X_val.transpose(1, 2), val_y=y_val)

# Accumulate into running lists kept across training runs.
accuracies.extend(acc)
losses.extend(loss)
if val_acc and val_loss:
    val_accuracies.extend(val_acc)
    val_losses.extend(val_loss)
total_epochs += EPOCHS

Hi Kaustubh!

These two lines of code are in conflict with one another:

    loss = loss_fn(outputs, y)

    matches = [torch.argmax(i) == torch.argmax(j) for i, j in zip(outputs, y)]

Your loss_fn, CrossEntropyLoss, expects its outputs argument to
have shape [nBatch, nClass], and its y argument to have shape
[nBatch] (no class dimension).
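
For instance, a minimal sketch of the shapes involved (four samples
and ten classes, both numbers made up for this example):

    import torch
    import torch.nn as nn

    loss_fn = nn.CrossEntropyLoss()
    outputs = torch.randn(4, 10)    # [nBatch, nClass] raw logits, no softmax
    y = torch.tensor([1, 0, 9, 3])  # [nBatch] integer class labels
    loss = loss_fn(outputs, y)      # scalar loss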

On the other hand, your torch.argmax(i) == torch.argmax(j)
test suggests that outputs and y have the same shape. I’m guessing
that your y has shape [nBatch] (because you don’t report that
CrossEntropyLoss has thrown an error), and therefore that your
matches test is incorrect.
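
A quick way to see why, assuming y does hold integer labels: each j
in zip(outputs, y) is then a 0-d tensor, and argmax of a 0-d tensor
is always 0, so the test really checks whether each prediction is
class 0:

    import torch

    j = torch.tensor(5)     # a single integer label (0-d tensor)
    print(torch.argmax(j))  # tensor(0) -- not 5

That would leave the reported accuracy stuck near chance level.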

What you want is for outputs to have shape [nBatch, nClass] (and
be the predicted logits for the nClass classes), and for y to have
shape [nBatch] and be integer class labels for each of the samples
in the batch. Then you want your test to be:

matches = (torch.argmax(outputs, dim=1) == y).sum()

matches will now be the number of correct predictions in your batch.
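
Putting the pieces together, the relevant lines of your fwd_pass
would then look something like this (a sketch; dividing the count by
the batch size turns it into a fraction):

    outputs = model(X)
    loss = loss_fn(outputs, y)                           # y: [nBatch] integer labels
    matches = (torch.argmax(outputs, dim=1) == y).sum()  # correct predictions
    acc = matches.item() / len(y)                        # fraction correct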

Best.

K. Frank

Thank you so much for this! I am completely new to PyTorch, so I knew I was doing something silly; I just did not know what it was. This worked out just fine!

I was wondering if you could tell me why the accuracy went up with MSELoss when I had one-hot labels?

Thank you again.