Accuracy not changing after second training epoch

Hi there. I'm kinda new to PyTorch. I have implemented a fully connected neural network and am trying to compute accuracy on the training samples. After the first epoch I get a certain accuracy, and after the second epoch I get a slightly better accuracy, great. But then the accuracy doesn't change. Can anyone point me to the bug, please? Note that I am not yet testing; I am still debugging my training algorithm.

#HINT: note that your training time should not take many days.

max_epoch = 100
train_batch = 4
test_batch = 500
learning_rate = 0.1
use_gpu = torch.cuda.is_available()

def main(): # you are free to change parameters
    
    xl_file = 'labels.xlsx'
    root_dir = 'C:\\Users\\t.anjary\\Desktop\\ML\\ML HW3\\images'
    train_set, val_set, test_set = get_dataset(xl_file, root_dir)
    
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=train_batch, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=test_batch, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch, shuffle=False)
    
    model = FNet().float() #.cuda()
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)#, momentum=0.9, weight_decay=5e-04) # you can play with momentum and weight_decay parameters as well
    
    for epoch in range(max_epoch):
        print('\n epoch '+ str(epoch))
        train(epoch, model, criterion, optimizer, train_loader)
    
''' Train your network for a one epoch '''
def train(epoch, model, criterion, optimizer, loader): # you are free to change parameters
    #print('train is called')
    accuracies =0
    model.train()
    correct = 0
    total = 0
    for batch_idx, (data, labels) in enumerate(loader):
        # TODO:
        # Implement training code for a one iteration
        
        #forward
        optimizer.zero_grad()
        data = data.float() #.cuda()
        labels = labels.float() #.cuda()
        y_pred = model(data)
        
        #loss
        loss = criterion(y_pred, labels)
        
        #update
        loss.backward()
        optimizer.step()
        
        #performance
        #y_pred = y_pred.cpu().numpy()
        y_pred = y_pred.flatten()
        total = labels.size(0)
        correct = (y_pred == labels).sum().item()
        accuracy = 100 * correct / total
        accuracies += accuracy 
        
        #info = 'iteration: ' + str(batch_idx) + '  loss: ' + str(loss.item()) + '  accuracy: ' + str(accuracy)
        #print(info)
        #sys.stdout.flush()
        sys.stdout.write('\r iteration: ' + str(batch_idx) + '  loss: ' + str(loss.item()) + '  accuracy: ' + str(accuracy))
        
        
    accuracies = accuracies/(batch_idx+1)
    print('\n training accuracy '+str(accuracies))

''' Test&Validate your network '''
def test(model, loader,criterion): # you are free to change parameters

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, labels) in enumerate(loader):
            data = data.float() #.cuda()
            labels = labels.float() #.cuda()
            
            #forward only
            y_pred = model(data)

            #loss
            loss = criterion(y_pred, labels)
        
            #assess
            y_pred = y_pred.flatten()
            total = labels.size(0)
            correct = (y_pred == labels).sum().item()
            acc = 100 * correct / total
            
            print('validation accuracy '+str(acc))
            return acc
        
        
main()

Hi Taher!

The short answer is that this line:

correct = (y_pred == labels).sum().item()

is a mistake because it is performing an exact-equality test on
floating-point numbers. (In general, doing so is a programming
bug except in certain special circumstances.)

(Note, this doesn’t affect your loss function, so your training
could be working.)

You don’t mention what sort of problem you’re working on or
what your data or model looks like, so we can’t really proceed
without some speculation. But if speculation is what you need,
you’ve come to the right place!

I speculate that you are working with images. Based on this, I
speculate that you are working on an object classification or
object detection problem.

There is something floating around out there called “FNet” that is
used for face detection. So maybe you are working on object
detection.

Your loss could be the mean-squared-error between the predicted
locations of objects detected by your object detector, and their
known locations as given in your annotated dataset. This could
make sense.

Based on this, both your y_pred and labels are floats.

This is a problem. Here you are performing an exact-equality test
of floating-point numbers. Doing so is (almost) always a mistake.

Even if your labels are nicely discrete values (say, pixel locations)
such as 1.00 or 503.00, for which an exact-equality test could
make sense, your y_pred are not. So the exact-equality test will
almost always fail, and the accuracy you calculate will always be
very low.

(Note, if you try to force your y_pred to be discrete, for example,
by rounding them to the nearest integral pixel location, you will
have performed a non-differentiable operation on them, so you
won’t be able to back-propagate usefully.)
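For what it's worth, here is a little sketch (with made-up numbers, not your data) of the two points above: exact equality almost never matches, and a non-differentiable comparison is fine for a metric as long as it is kept out of the loss:

import torch

y_pred = torch.tensor([0.999, 2.003, 5.000])   # made-up "close but not exact" predictions
labels = torch.tensor([1.0, 2.0, 5.0])

print((y_pred == labels).sum().item())          # 1 -- exact equality almost never matches

with torch.no_grad():                           # metric only, never touches the loss
    correct = torch.isclose(y_pred, labels, atol=1e-2).sum().item()
print(correct)                                  # 3 -- tolerance-based comparison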

As to whether the rest of your code makes sense, it’s hard to
say without knowing what you are trying to do.

Good luck.

K. Frank

Hi there. My apologies, I will give more details. The task is to perform binary classification on the Ocular Disease Recognition (ODIR5k) dataset. The images are 128x256 RGB images. I convert them to grayscale and flatten them in my custom dataset class:

class OcularDataset(Dataset):
    
    # TODO:
    # Define constructor for AnimalDataset class
    # HINT: You can pass processed data samples and their ground truth values as parameters 
    def __init__(self, xl_file, root_dir):
        #print('init was called')
        self.index_labels = pd.read_excel(xl_file)
        self.root_dir = root_dir
        self.n_samples = len(self.index_labels)
        
    '''This function should return sample count in the dataset'''
    def __len__(self):
        return self.n_samples

    '''This function should return a single sample and its ground truth value from the dataset corresponding to index parameter '''
    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir,str(self.index_labels.iloc[index,0])+'.jpg')
        img = mpimg.imread(img_path)
        rgb_weights = [0.2989, 0.5870, 0.1140]
        img = np.dot(img, rgb_weights)/255
        img = torch.tensor(img)
        img = img.reshape([1,-1])
        y = torch.tensor(int(self.index_labels.iloc[index,1]))
        
        return img, y


def get_dataset(xl_file, root_dir):
    fdataset = OcularDataset(xl_file, root_dir)
    
    train_set, val_set, test_set = torch.utils.data.random_split(fdataset, [2500,500,500])
    
    return train_set, val_set, test_set

Sample original image:


After the custom dataset, I implement a neural network class:

class FNet(nn.Module):
    
    def __init__(self): 
        super(FNet, self).__init__()
        self.linear1 = torch.nn.Linear(128*256, 1024)
        self.linear2 = torch.nn.Linear(1024, 256)
        self.linear3 = torch.nn.Linear(256, 1)
     
    def forward(self, X): 
        z1 = torch.nn.functional.relu(self.linear1(X))
        z2 = torch.nn.functional.relu(self.linear2(z1))
        y = torch.sigmoid(self.linear3(z2))
        
        return y

And then comes the code I had attached earlier, but with a few changes based on K. Frank's insight about equality comparisons between predictions and labels. If you're wondering why I convert everything to float, it's because doing so solved other errors I was getting.

#HINT: note that your training time should not take many days.

max_epoch = 100
train_batch = 25
test_batch = 500


learning_rate = 0.0001

use_gpu = torch.cuda.is_available()


def main(): # you are free to change parameters
    
    xl_file = 'labels.xlsx'
    root_dir = 'C:\\Users\\t.anjary\\Desktop\\ML\\ML HW3\\images'
    train_set, val_set, test_set = get_dataset(xl_file, root_dir)
    
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=train_batch, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=test_batch, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch, shuffle=False)
    
    model = FNet().float() #.cuda()
    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.5, weight_decay=5e-04) # you can play with momentum and weight_decay parameters as well
    
    for epoch in range(max_epoch):
        print('\n epoch '+ str(epoch))
        train(epoch, model, criterion, optimizer, train_loader)
        test(model, test_loader,criterion)
    
''' Train your network for a one epoch '''
def train(epoch, model, criterion, optimizer, loader): # you are free to change parameters
    #print('train is called')
    accuracies = 0
    losses = 0
    model.train()
    correct = 0
    total = 0
    for batch_idx, (data, labels) in enumerate(loader):
        # TODO:
        # Implement training code for a one iteration
        
        #forward
        optimizer.zero_grad()
        data = data.float() #.cuda()
        labels = labels.float() #.cuda()
        y_pred = model(data)
        
        #loss
        loss = criterion(y_pred, labels)
        
        #update
        loss.backward()
        optimizer.step()
        
        #performance
        #y_pred = y_pred.cpu().numpy()
        y_pred = y_pred.flatten()
        total = len(y_pred)#labels.size(0)
        
        y_2 = torch.zeros(len(y_pred))
        y_2[y_pred>=0.5] = 1
        y_2 = y_2.int()
        correct = 1*(y_2 == labels.int()).sum().item()
        correct = (y_pred == labels).sum().item()
        accuracy = correct / total
        accuracies += accuracy 
        losses += loss
        
        #info = 'iteration: ' + str(batch_idx) + '  loss: ' + str(loss.item()) + '  accuracy: ' + str(accuracy)
        #print(info)
        #sys.stdout.flush()
        #sys.stdout.write('\r iteration: ' + str(batch_idx) + '  loss: ' + str(loss.item()) + '  accuracy: ' + str(accuracy))
        
        
    accuracies = accuracies/(batch_idx+1)
    losses = losses/(batch_idx+1)
    print('training accuracy '+str(accuracies))
    print('training loss '+str(losses.item()))

''' Test&Validate your network '''
def test(model, loader,criterion): # you are free to change parameters

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, labels) in enumerate(loader): # this loop will iterate once only
            data = data.float() #.cuda()
            labels = labels.float() #.cuda()
            
            #forward only
            y_pred = model(data)

            #loss
            loss = criterion(y_pred, labels)
        
            #assess
            y_pred = y_pred.flatten()
            y2 = torch.zeros(len(y_pred))
            y2[y_pred>=0.5] = 1
            y2 = y2.int()
            total = labels.size(0)
            correct = (y2 == labels.int()).sum().item()
            acc = correct / total
            
            print('validation accuracy '+str(acc))
            print('validation loss '+str(loss.item()))
            #return acc
        
        
main()

In any case, running the code produces the following output:

epoch 0
training accuracy 0.0
training loss 0.24466806650161743
validation accuracy 0.674
validation loss 0.23942209780216217

 epoch 1
training accuracy 0.0
training loss 0.23551715910434723
validation accuracy 0.674
validation loss 0.23317639529705048

 epoch 2
training accuracy 0.0
training loss 0.2304917573928833
validation accuracy 0.674
validation loss 0.22967855632305145

 epoch 3
training accuracy 0.0
training loss 0.22782883048057556
validation accuracy 0.674
validation loss 0.22740887105464935

 epoch 4
training accuracy 0.0
training loss 0.22573482990264893
validation accuracy 0.674
validation loss 0.22585797309875488

 epoch 5
training accuracy 0.0
training loss 0.22470112144947052
validation accuracy 0.674
validation loss 0.22480493783950806

 epoch 6
training accuracy 0.0
training loss 0.22381442785263062
validation accuracy 0.674
validation loss 0.22404111921787262

 epoch 7
training accuracy 0.0
training loss 0.22296249866485596
validation accuracy 0.674
validation loss 0.22347119450569153

 epoch 8
training accuracy 0.0
training loss 0.222427099943161
validation accuracy 0.674
validation loss 0.2230590283870697

 epoch 9
training accuracy 0.0
training loss 0.22189371287822723
validation accuracy 0.674
validation loss 0.22275172173976898

 epoch 10
training accuracy 0.0
training loss 0.22178654372692108
validation accuracy 0.674
validation loss 0.2225237935781479

 epoch 11
training accuracy 0.0
training loss 0.221751868724823
validation accuracy 0.674
validation loss 0.22235801815986633

 epoch 12
training accuracy 0.0
training loss 0.22149959206581116
validation accuracy 0.674
validation loss 0.22224263846874237

 epoch 13
training accuracy 0.0
training loss 0.22099249064922333
validation accuracy 0.674
validation loss 0.22214549779891968

As you may observe, nothing is changing. I really can't seem to spot the flaw. I suspect I may not be using zero_grad correctly, but I'm really not sure. I have tried several learning rates (0.1, 0.01, 0.00, 0.0000001) and different batch sizes (4, 5, 10, 25), but nothing changes. Any advice would be much appreciated. Please help.

You can use lr_scheduler.ReduceLROnPlateau

https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

You can read more about it in the above link
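
For example, a minimal sketch of how it could be wired into the training loop from this thread (assuming test() is changed to return the validation loss):

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

for epoch in range(max_epoch):
    train(epoch, model, criterion, optimizer, train_loader)
    val_loss = test(model, val_loader, criterion)   # assumes test() returns the loss
    scheduler.step(val_loss)                        # lower lr when the val loss stops improving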

Hi Taher!

There are several issues here. See my in-line comments, below.

For binary classification, you should be using BCEWithLogitsLoss
and get rid of the sigmoid() at the end of your model. (You could
also keep the sigmoid() and use BCELoss, but with greater numerical
instability.) MSELoss isn’t the right choice for binary classification.
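
To make that concrete, here is a rough sketch of the change, based on the FNet you posted (FNetLogits is just a name I made up):

import torch
import torch.nn as nn

class FNetLogits(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128 * 256, 1024)
        self.linear2 = nn.Linear(1024, 256)
        self.linear3 = nn.Linear(256, 1)

    def forward(self, x):
        z1 = torch.relu(self.linear1(x))
        z2 = torch.relu(self.linear2(z1))
        return self.linear3(z2)            # raw logits -- no sigmoid here

criterion = nn.BCEWithLogitsLoss()         # applies the sigmoid internally, more stably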

As an aside, your rgb images have information in the color. So it
would make sense not to convert them to grayscale. But to make
use of the color information, you would have to create a model
that did make use of the color – typically by having it accept inputs
with a channel dimension for the three color channels.
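
As a sketch only (this is not your assignment's architecture, just an illustration I'm making up), such a model would accept inputs of shape (N, 3, 128, 256) instead of flattened grayscale vectors:

import torch
import torch.nn as nn

class TinyColorNet(nn.Module):                   # made-up example model
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(4)              # 128x256 -> 32x64
        self.fc = nn.Linear(16 * 32 * 64, 1)     # one logit for binary classification

    def forward(self, x):                        # x: (N, 3, 128, 256)
        x = self.pool(torch.relu(self.conv1(x)))
        return self.fc(x.flatten(1))

print(TinyColorNet()(torch.randn(2, 3, 128, 256)).shape)   # torch.Size([2, 1])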

As mentioned above, get rid of the sigmoid() at the end of FNet's forward().

As mentioned above, you should use BCEWithLogitsLoss as your criterion in place of MSELoss.

Your labels should, indeed, be floats for use with BCEWithLogitsLoss.
By the way, what type were they before you called float()?

Also, if you continue to have problems, you should check what values
they take on. They most likely take on values of 0.0 (the “class label”
for class-“0”) and 1.0 (for class-“1”). But they could be values
(understood to be probabilities) that range over [0.0, 1.0]. In any
event, whether discrete or continuous, they should be floats.
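
A quick way to check (just a sanity-check sketch, using the train_loader from your code):

data, labels = next(iter(train_loader))
print(labels.dtype)             # integer labels need .float() before BCEWithLogitsLoss
print(labels.unique())          # expect tensor([0, 1]) for hard class labels
print(labels.float().mean())    # fraction of positive samples in this batch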

Your thresholding of y_pred (the output of your model) at 0.5, via y_2, is a reasonable way to convert it to yes-no (0-1) predictions, suitable for computing your accuracy.

Note, this is true if y_pred is made up of probabilities that range
over [0.0, 1.0], as will be the case if you have the final sigmoid()
in your model. However, if you get rid of the sigmoid() and use
BCEWithLogitsLoss (as you should), your outputs will be raw-score
logits that range over (-inf, inf), and you should use a threshold
of 0.0, rather than 0.5, that is:

# y_2[y_pred>=0.5] = 1   # change this
y_2[y_pred>=0.0] = 1     # to this

Also, your accuracy code is somewhat roundabout, and could be
tightened up. But that is left as an exercise for the reader …

This, however, is a problem: after calculating correct in a sensible
way (with y_2), you immediately overwrite it with the exact-equality-test
version, which is incorrect. So your (train) accuracy will be wrong in the
same way as it was in your first version.
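
That is, in train() only the thresholded comparison should survive:

correct = (y_2 == labels.int()).sum().item()
# correct = (y_pred == labels).sum().item()   # <-- delete this overwriting line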

I see that you’ve gotten rid of the factor of 100 you had in your first
version. I suppose that you are now calculating your accuracy as a
“fraction” where it had been a “percentage” before. No matter, but it
makes comparing this version with the first more confusing.

As an aside, in test() you are calculating accuracy correctly, and
not making the exact-equality-test mistake you kept in train().
This explains the difference between your training accuracy and
validation accuracy in your results below.

Redo this with BCEWithLogitsLoss and the correct accuracy
calculation.
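
Putting the pieces together, one training iteration might look roughly like this sketch (assuming the model now returns raw logits and that labels arrive with shape (N,)):

optimizer.zero_grad()
y_pred = model(data.float()).flatten()              # raw logits, shape (N,)
loss = criterion(y_pred, labels.float())            # criterion = nn.BCEWithLogitsLoss()
loss.backward()
optimizer.step()

with torch.no_grad():
    preds = (y_pred >= 0.0).float()                 # logit >= 0  <=>  probability >= 0.5
    accuracy = (preds == labels.float()).float().mean().item()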

Your use of optimizer.zero_grad() is correct.

After you get your code more-or-less working, it does make sense
to try different learning rates. Note that a learning rate of 0.00
doesn't make sense: it multiplies the gradient by 0.00, in effect
doing nothing. Also, 0.0000001 seems awfully small.

And experimenting with different batch sizes makes sense, as well.

Good luck.

K. Frank

In my case, I was facing the same issue. On my laptop without a GPU the training was fine, but when I tried on a GPU the model's accuracy and loss didn't change after the first few epochs. I was using nn.CrossEntropyLoss() with Adam.
Replacing Adam with SGD worked for me.
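
For example (illustrative hyperparameters only):

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)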