Using Sigmoid Function in CNN Classifier

Hey everybody,

I'm trying to use a CNN classifier on clinical data (input shape 39x12, rows for values and columns for time intervals) to predict a categorical outcome such as good/bad result; so I'm using Conv2d with kernel size (1,3) or (1,5) to find time-sensitive patterns (each row == 1 clinical parameter) and a max-pooling kernel of (1,2).

After running the forward pass I apply a sigmoid function to the output node of the last linear layer to get a probability between 0 and 1 for the category "good"; I use these probabilities to compute the loss against the labels [0, 1] and backpropagate on it.

I'm asking myself whether the architecture is correct and, if it is, what else I can do to improve performance, because the loss isn't really decreasing and the accuracy is random at the beginning of training and ends up getting worse.

# CNN Model
# Output Size of Convolutional Layer by ((Width - FilterSize + (Padding * 2)) / Stride) + 1
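# Worked through for this net (assuming the input is laid out as [N, 1, 39, 12],
# i.e. 39 parameter rows and 12 time steps):
#   conv1 (1,5), pad (0,1): width (12 - 5 + 2) / 1 + 1 = 10  -> [N, 6, 39, 10]
#   pool  (1,2), stride 1:  width 10 - 2 + 1 = 9             -> [N, 6, 39, 9]
#   conv2 (1,3), pad (0,1): width (9 - 3 + 2) / 1 + 1 = 9    -> [N, 16, 39, 9]
#   pool  (1,2), stride 1:  width 9 - 2 + 1 = 8              -> [N, 16, 39, 8]
#   flatten: 16 * 39 * 8  (matches fc1)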

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, (1,5), padding = (0,1), dilation = 1).double()
        # could use avg_pool2d 
        self.pool = nn.MaxPool2d((1,2), 1)
        self.conv2 = nn.Conv2d(6, 16, (1,3), padding = (0,1), dilation = 1).double()
        self.fc1 = nn.Linear(16 * 39 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.float()
        x = x.view(-1,16*39*8)
        # or use x = torch.flatten(x, 1) to flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Probabilities through sigmoid function
        x = torch.sigmoid(x)
        return x
    
    def binary(self,y):
        # round the sigmoid output to 0 / 1 to classify and calculate accuracy; KODIETZ 221017
        y = torch.round(y)
        return y
    
# Define model
net = ConvNet()

import torch.optim as optim
# define optimizer & Loss Function; get model
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

# training
total_datapoints = len(trainloader)

for epoch in range(num_epochs):
    running_loss = 0.0
    epoch_acc = 0.0
    epoch_loss = 0.0
    print(f'-----------Start Epoch {epoch + 1} -----------------------')
    for i,(input_torch,labels) in enumerate(trainloader):
        
        # origin shape: [batch, channels, height, width]
        # here we got [5, 1, 39, 12] = 5, 1, 468
        input_torch = input_torch.double().to(device)
        labels = labels.float().to(device)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = net(input_torch).reshape(1,len(input_torch))[0]
        loss = criterion(outputs, labels)

        # Accuray for Batch
        y_pred = net.binary(outputs)
        correct = sum(labels == y_pred).item()
        acc = correct / len(input_torch)
        epoch_acc += acc
        
        # Backward Propagation and Optimization
        loss.backward()
        optimizer.step()
        
        # print statistics
        epoch_loss += loss.item()
        if i % 5 == 0:
            print(f'{epoch + 1}, {i + 1}\t|\t loss: {loss.item():.5f} \t| acc: {acc}')
            
    print(f'-----------End Epoch {epoch + 1}: | Loss: {epoch_loss/len(trainloader):.5f} | Acc: {epoch_acc/len(trainloader):.3f}')

thanks for any help!
best regards

Hi Obrey!

Short story: Get rid of the sigmoid() and use BCEWithLogitsLoss.

This is a binary classification problem. You could treat it as a two-class
multi-class problem (and use CrossEntropyLoss), but you’re better off
using pytorch’s support for the “special case” of binary classification (and
use BCEWithLogitsLoss).

You have a mismatch here between your final layer, fc3, and your loss
criterion CrossEntropyLoss. If you wanted to treat your problem as
a two-class multi-class problem (but you don’t really), you would want
your final layer to have two outputs (e.g., Linear (84, 2)). As it stands,
your model won’t work.
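
Roughly, that two-class version would look like this (just a sketch,
with dummy tensors standing in for the pieces you haven't shown):

import torch
import torch.nn as nn

# two-class alternative (not what I'd recommend here): the final
# layer has two outputs and CrossEntropyLoss expects raw logits
# plus integer class labels
fc3 = nn.Linear(84, 2)
criterion = nn.CrossEntropyLoss()

logits = fc3(torch.randn(5, 84))          # shape [batch, 2], no softmax
labels = torch.tensor([0, 1, 1, 0, 1])    # dtype long, values 0 / 1
loss = criterion(logits, labels)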

Leave fc3 as it is, get rid of sigmoid(), and use BCEWithLogitsLoss.

(You could keep the sigmoid() and use BCELoss, but doing so would
be less numerically stable and would have no offsetting advantages.)
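
In code the change is small (a minimal sketch, with dummy tensors
standing in for your model):

import torch
import torch.nn as nn

# binary setup: keep the single-output final layer and drop the
# sigmoid from forward(); BCEWithLogitsLoss applies it internally
fc3 = nn.Linear(84, 1)
criterion = nn.BCEWithLogitsLoss()

logits = fc3(torch.randn(5, 84)).squeeze(1)   # shape [batch], raw logits
labels = torch.tensor([0., 1., 1., 0., 1.])   # float labels, 0.0 / 1.0
loss = criterion(logits, labels)

# to classify / compute accuracy, threshold the logits at 0.0
# (equivalent to sigmoid(logit) > 0.5)
preds = (logits > 0.0).float()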

Best.

K. Frank

Thank you very much @KFrank

I came from a 2-class classifier and forgot to change the loss criterion, ouch.

Anyhow, the classifier is still not satisfying :frowning: - because of an imbalanced dataset I provided a weight attribute for the loss, but it is still not performing properly;
while the training loss is decreasing, the accuracy is increasing;

Any suggestions for proper tuning?

best, obrey

Hi Obrey!

Just to confirm, if you are using BCEWithLogitsLoss, you should be
using its pos_weight constructor argument to increase the weight of
the positive samples if they are underrepresented (or decrease their
weight if they are overrepresented).
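
For example (the counts here are made up -- use the counts from your
training set):

import torch
import torch.nn as nn

num_pos = 120.0   # number of positive ("good") samples in the training set
num_neg = 480.0   # number of negative ("bad") samples
pos_weight = torch.tensor([num_neg / num_pos])   # > 1 if positives are underrepresented

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)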

Could you clarify this? “Training loss is decreasing” – good, this is what
training is supposed to do. “The accuracy is increasing.” To me, “accuracy”
means the percentage of predictions that are correct. So this would also
be good.

Did you mean that the accuracy was decreasing? Were you contrasting
training vs. test (or validation) loss?

Best.

K. Frank