PyTorch NN regression model does not learn

I’m very new to PyTorch and I’m stuck getting my model to converge. It does not seem to be learning, since the loss and r2 do not improve.

What I’ve checked/tried based on the suggestions I found here:

  1. changed the loss function / rewrote it from scratch
  2. set “loss.requires_grad = True”
  3. tried feeding the data without a DataLoader, using plain manual batches
  4. played with 2D data / mean-pooled data. I got decent results for mean-pooled data with Random Forest and SVM regressors, but not with the NN, which confuses me and makes me think that the data is OK and the net is NOT OK!
  5. played with the learning rate and batch size

etc.

About the data: the input is BERT embeddings of letter sequences. Each data point has 43 rows of 1024 features (for Conv1d I transpose it to 1024*43). There are >40K data points in the training set; batch size = 64.
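To make the shapes concrete, here is a minimal sketch of how a batch of this data would pass through the first layers of the model below. The batch of 8 random samples is a made-up stand-in for the real embeddings; only the 43 x 1024 layout and the kernel size come from the post.

```python
import torch
import torch.nn as nn

# Hypothetical small batch standing in for the BERT embeddings:
# 8 sequences, 43 positions, 1024 features per position.
batch = torch.randn(8, 43, 1024)

# Conv1d expects (batch, channels, length), so the feature axis moves to dim 1.
x = batch.transpose(1, 2)          # -> (8, 1024, 43)

conv = nn.Conv1d(1024, 512, kernel_size=4)
h = conv(x)                        # length shrinks to 43 - 4 + 1 = 40
flat = torch.flatten(h, 1)         # -> (8, 512 * 40)

print(x.shape, h.shape, flat.shape)
```

This is why the first Linear layer in the model takes 512*40 input features.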

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F ## relu, tanh
import torch.utils.data as DataLoader # helps create batches to train on
from scipy.stats import pearsonr
import numpy as np
import torch.utils.data as data_utils
torch.set_printoptions(precision=10)


#Hyperparameters
learning_rate=0.001
batch_size = 64
num_epochs=100


data_train1 = torch.Tensor(data_train)
targets_train1=torch.Tensor(targets_train)

dataset_train = data_utils.TensorDataset(data_train1, targets_train1)
train_loader = DataLoader.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)


class NN (nn.Module):
    def __init__(self):#input_size=43x1024
        super(NN,self).__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(1024, 512, kernel_size=4), #I tried different in and out here
            nn.ELU(),
            nn.BatchNorm1d(512),
            nn.Flatten(),
            nn.Linear(512*40, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        
    def forward(self, x):
        return self.layers(x)

torch.manual_seed(100)

#Initialize network
model=NN().to(device)

#Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate) #to check 


#Training the model
metric_name='r2'

for epoch in range(num_epochs):
    score=[]
    loss_all=[]
    print(f"Epoch: {epoch+1}/{num_epochs}")

    model.train()

    for batch_idx, (data, targets) in enumerate(train_loader):
        data=data.to(device=device)
        targets=targets.to(device=device)

        optimizer.zero_grad()
        
        #forward
        predictions=model(data)
        
        loss=criterion(predictions,targets).to(device)

        loss.requires_grad = True

        #backward
        loss_all.append(loss.item())
        loss.backward()
    
        #gradient descent or adam step
        optimizer.step()
        
        #computing r squared
        output=predictions.detach().cpu().numpy()
        target=targets.detach().cpu().numpy()
        output=np.squeeze(output)
        target=np.squeeze(target)
    
        score.append(pearsonr(target, output)[0]**2)

    total_score = sum(score)/len(score)
    print(f'training {metric_name}: {total_score}, mean loss: {sum(loss_all)/len(loss_all)}')

Output for the first 10 epochs:
Epoch: 1/100 training r2: 0.0026224905802415955, mean loss: 0.5084380856556941
Epoch: 2/100 training r2: 0.0026334153423518466, mean loss: 0.5082988155293148
Epoch: 3/100 training r2: 0.002577073836564485, mean loss: 0.5085703951569392
Epoch: 4/100 training r2: 0.002633483899689855, mean loss: 0.5081870414129565
Epoch: 5/100 training r2: 0.0025642136678393776, mean loss: 0.5083346445680192
Epoch: 6/100 training r2: 0.0026261540869286933, mean loss: 0.5084220717277274
Epoch: 7/100 training r2: 0.002614604670602339, mean loss: 0.5082813335398275
Epoch: 8/100 training r2: 0.0024826257263258784, mean loss: 0.5086268588042153
Epoch: 9/100 training r2: 0.00261018096876641, mean loss: 0.5082496945227619
Epoch: 10/100 training r2: 0.002542892071836945, mean loss: 0.5088265852086478

The response is float64 in the range (-2, 2).

With the response scaled to the range [-1, 1] (float64) and with tanh, it still does not converge. I feel like something general is missing. BTW, when I use non-shuffled batches taken straight from the data (first batch with indices 0-63, second batch 64-127, etc.), I get the same score in every epoch!
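Identical scores in every epoch on non-shuffled batches suggest that the weights may not be changing at all. A minimal sketch of how one could check that directly, using a tiny stand-in model and random data (the sizes are hypothetical, not the real network):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Tiny stand-in model and data (hypothetical sizes, not the real network).
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
x, y = torch.randn(32, 16), torch.randn(32, 1)

# Snapshot the weights before one training step.
before = [p.detach().clone() for p in model.parameters()]

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()

# A healthy step leaves gradients populated and at least some weights changed.
has_grads = all(p.grad is not None for p in model.parameters())
changed = any(not torch.equal(b, p.detach())
              for b, p in zip(before, model.parameters()))
print(f"gradients present: {has_grads}, weights changed: {changed}")
```

If either value comes out False in the real training loop, the problem is in the gradient flow rather than in the data or the architecture.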

I tried adding two more Conv1d, BatchNorm1d, ELU blocks (1024->512, 512->256, 256->128, kernel size 4) and the result is worse :( It is not learning at all!

I also tried mean-pooling the data so that the input has only 1024 features, with the same poor result for the NN.

It seems to me that I’m missing something very general and the model just does not learn.

Hope you can help! Thank you in advance for your time!

Hello,
Since the training r2 is changing (0.0026 to 0.0024), I think it is learning even if the loss is not decreasing at the beginning; it may decrease towards the end. If not, try running more epochs (more than 100) to give it enough time to converge.
Otherwise, maybe I didn’t understand your question.

Thank you, I will try to run 10K epochs today to see if it converges.
Yes, my question was why the model is not converging (I expect r^2 to be in about the 0.5-0.6 range, as I got this result from Random Forest).

No improvement after 10K epochs either ;(

I think a classic test to try here is to decrease the amount of data in the dataset and verify that your model can overfit it.
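A minimal sketch of such an overfitting test, assuming a handful of random samples and a small stand-in network (both hypothetical). A model whose training loop works should drive the loss on these few points close to zero:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical tiny dataset: 8 samples the model should be able to memorize.
x = torch.randn(8, 16)
y = torch.randn(8, 1)

model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

losses = []
for step in range(300):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
```

If the loss stays flat even here, the issue is very likely in the training loop or the autograd setup, not in the data.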

Thank you for the idea!
But the result is the same; I just ran it on 10K data points.

Still not converging. Something is wrong with its learning ability…

Right, what happens when you reduce this to an extreme, say just a few or even a single data point?

This is what I got:

5 data points:
Epoch: 1/100
training r2: 0.13436292977865993, mean loss: 0.33448025584220886

Epoch: 100/100
training r2: 0.1343612927439954, mean loss: 0.33448025584220886

2 data points:
Epoch: 1/100
training r2: 1.0, mean loss: 0.6390723586082458
Epoch: 100/100
training r2: 1.0, mean loss: 0.6390724778175354

1 data point:
ValueError: x and y must have length at least 2.
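A note on reading these numbers: the r2 in the training loop is the squared Pearson correlation from scipy, and two distinct points always lie on a straight line, so it comes out as 1.0 for two data points no matter how far off the predictions are. The sketch below uses made-up values to compare it with the coefficient of determination computed by hand in NumPy:

```python
import numpy as np
from scipy.stats import pearsonr

# Made-up targets and predictions for two data points; the predictions
# are far from the targets but vary in the same direction.
target = np.array([0.5, -1.0])
output = np.array([10.0, 9.0])

# Squared Pearson correlation, as used in the training loop.
pearson_r2 = pearsonr(target, output)[0] ** 2

# Coefficient of determination: 1 - SS_res / SS_tot.
ss_res = np.sum((target - output) ** 2)
ss_tot = np.sum((target - target.mean()) ** 2)
det_r2 = 1.0 - ss_res / ss_tot

print(f"squared Pearson r: {pearson_r2:.3f}, R^2: {det_r2:.3f}")
```

So in the two-point run the unchanged mean loss is the more informative signal: the model did not move between epoch 1 and epoch 100.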

Interesting, I’ve tried your model code with random training data and it seems to train fine. However, one difference is that I removed the loss.requires_grad = True line as that causes

Traceback (most recent call last):
  File "tempsequence.py", line 77, in <module>
    loss.requires_grad = True
RuntimeError: you can only change requires_grad flags of leaf variables.

in my environment.
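For reference, a minimal sketch that should reproduce this error with autograd in its default (enabled) state. The tiny linear model and random data are stand-ins:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

loss = criterion(model(x), y)

# With autograd enabled, the loss already tracks gradients and is the
# result of an operation, so it is not a leaf tensor.
print(loss.requires_grad, loss.is_leaf)

try:
    loss.requires_grad = True
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

In other words, a loss computed from a model's parameters should already have requires_grad set; if the assignment succeeds silently, no graph was recorded for the loss.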
Output:

Epoch: 1/100
training r2: 0.009594063912129953, mean loss: 5.682353660464287
Epoch: 2/100
training r2: 0.16930472195690563, mean loss: 26.084057807922363
Epoch: 3/100
training r2: 0.2011665253106225, mean loss: 11.448175303637981
Epoch: 4/100
training r2: 0.09114844879538331, mean loss: 2.279003791511059
Epoch: 5/100
training r2: 0.3172596296459851, mean loss: 0.7606077138334513
Epoch: 6/100
training r2: 0.4809859426056538, mean loss: 0.5154492985457182
Epoch: 7/100
training r2: 0.5868341407534547, mean loss: 0.4187624929472804
Epoch: 8/100
training r2: 0.649767417816389, mean loss: 0.34969478007405996
Epoch: 9/100
training r2: 0.7071613647219848, mean loss: 0.29750150814652443
Epoch: 10/100
training r2: 0.7642161760963712, mean loss: 0.2450719503685832
Epoch: 11/100
training r2: 0.8191470208232381, mean loss: 0.18541160179302096
Epoch: 12/100
training r2: 0.866612613881297, mean loss: 0.13777912221848965
Epoch: 13/100
training r2: 0.9005741764866068, mean loss: 0.09871353628113866
Epoch: 14/100
training r2: 0.9297122659354237, mean loss: 0.07004990486893803
Epoch: 15/100
training r2: 0.9539838535204129, mean loss: 0.047599957906641066
Epoch: 16/100
training r2: 0.9682029463044977, mean loss: 0.03334217533119954
Epoch: 17/100
training r2: 0.9779388192414571, mean loss: 0.022128620053990744
Epoch: 18/100
training r2: 0.9855749007781102, mean loss: 0.014580676259356551
Epoch: 19/100
training r2: 0.9902337714874327, mean loss: 0.011297925255348673
Epoch: 20/100
training r2: 0.9925735608177005, mean loss: 0.007077659640344791
Epoch: 21/100
training r2: 0.9948008373417658, mean loss: 0.004619108618499013
Epoch: 22/100
training r2: 0.9967165985752503, mean loss: 0.003157338052915293
Epoch: 23/100
training r2: 0.9978223330030351, mean loss: 0.002256347268485115
Epoch: 24/100
training r2: 0.9982408967003787, mean loss: 0.001730209368361102
Epoch: 25/100
training r2: 0.9987114706485303, mean loss: 0.0013457234599627554
Epoch: 26/100
training r2: 0.9987114128846947, mean loss: 0.001346579964774719
Epoch: 27/100
training r2: 0.9985214830016457, mean loss: 0.0015366310344688827
Epoch: 28/100
training r2: 0.9979733118207817, mean loss: 0.0020749795767187607
Epoch: 29/100
training r2: 0.9976363805488251, mean loss: 0.0022758882041671313
Epoch: 30/100
training r2: 0.9969927904895611, mean loss: 0.0030411550051212544
Epoch: 31/100
training r2: 0.995648326719438, mean loss: 0.005057528265751898
Epoch: 32/100
training r2: 0.9922587923077028, mean loss: 0.008088470793154556
Epoch: 33/100
training r2: 0.9887959956127246, mean loss: 0.012593698847922496
Epoch: 34/100
training r2: 0.9831638414540375, mean loss: 0.018604541895911098
Epoch: 35/100
training r2: 0.9737390985477979, mean loss: 0.027635501814074814
Epoch: 36/100
training r2: 0.9641899244039341, mean loss: 0.038666998385451734
Epoch: 37/100
training r2: 0.951262654867677, mean loss: 0.052705009758938104
Epoch: 38/100
training r2: 0.9371287246394816, mean loss: 0.06455437233671546
Epoch: 39/100
training r2: 0.9568622813468258, mean loss: 0.04678037913981825
Epoch: 40/100
training r2: 0.9684106859661721, mean loss: 0.03428023273590952
Epoch: 41/100
training r2: 0.9769530552075255, mean loss: 0.024533933261409402
Epoch: 42/100
training r2: 0.986115178609799, mean loss: 0.01612667564768344
Epoch: 43/100
training r2: 0.9883539015220295, mean loss: 0.012954877864103764
Epoch: 44/100
training r2: 0.9924869826615559, mean loss: 0.0088661702175159
Epoch: 45/100
training r2: 0.9942897541431738, mean loss: 0.006273552269703941
Epoch: 46/100
training r2: 0.9968140855136599, mean loss: 0.0035802944257739
Epoch: 47/100
training r2: 0.997001327959991, mean loss: 0.003321810443594586
Epoch: 48/100
training r2: 0.9975052142543458, mean loss: 0.0028207608411321416
Epoch: 49/100
training r2: 0.9981577880948567, mean loss: 0.001989818636502605
Epoch: 50/100
training r2: 0.998546231725961, mean loss: 0.0016351598242181353
Epoch: 51/100
training r2: 0.9986027636978555, mean loss: 0.0015020208647911204
Epoch: 52/100
training r2: 0.9987384508660706, mean loss: 0.0014546711390721612
Epoch: 53/100
training r2: 0.9986857161310916, mean loss: 0.0015857035414228449
Epoch: 54/100
training r2: 0.9987154541868536, mean loss: 0.001486757697421126
Epoch: 55/100
training r2: 0.9986375054512535, mean loss: 0.0015756650318508036
Epoch: 56/100
training r2: 0.9983302812352135, mean loss: 0.001992172579775797
Epoch: 57/100
training r2: 0.9978964350882717, mean loss: 0.0023003785172477365
Epoch: 58/100
training r2: 0.9980009710848087, mean loss: 0.0022060492519813124
Epoch: 59/100
training r2: 0.9975635621751086, mean loss: 0.0026615302267600782
Epoch: 60/100
training r2: 0.9975992688318768, mean loss: 0.002975316012452822
Epoch: 61/100
training r2: 0.9963099659802808, mean loss: 0.004268930060788989
Epoch: 62/100
training r2: 0.9960123018852073, mean loss: 0.004779360577231273
Epoch: 63/100
training r2: 0.9953000198690384, mean loss: 0.005165318536455743
Epoch: 64/100
training r2: 0.9941095220427116, mean loss: 0.006454882866819389
Epoch: 65/100
training r2: 0.9930570804636689, mean loss: 0.007594891852932051
Epoch: 66/100
training r2: 0.9928955292326065, mean loss: 0.008498512339428999
Epoch: 67/100
training r2: 0.9911527403046312, mean loss: 0.010306075535481796
Epoch: 68/100
training r2: 0.9931718973224487, mean loss: 0.008304626011522487
Epoch: 69/100
training r2: 0.9933063535071627, mean loss: 0.00726965151989134
Epoch: 70/100
training r2: 0.993171505320739, mean loss: 0.008076088663074188
Epoch: 71/100
training r2: 0.994059443074772, mean loss: 0.0072878144710557535
Epoch: 72/100
training r2: 0.9931785372577545, mean loss: 0.007667907964787446
Epoch: 73/100
training r2: 0.9930847791765954, mean loss: 0.006984286243095994
Epoch: 74/100
training r2: 0.9955165209328762, mean loss: 0.004958275567332748
Epoch: 75/100
training r2: 0.9954397187994897, mean loss: 0.005190604671952315
Epoch: 76/100
training r2: 0.9958678910723257, mean loss: 0.005277530166495126
Epoch: 77/100
training r2: 0.9953963619106224, mean loss: 0.005750581309257541
Epoch: 78/100
training r2: 0.9963880844229783, mean loss: 0.004177943723334465
Epoch: 79/100
training r2: 0.9970984115343485, mean loss: 0.0033548544597579166
Epoch: 80/100
training r2: 0.996957118488398, mean loss: 0.0037890574349148665
Epoch: 81/100
training r2: 0.9953908147342319, mean loss: 0.004737168623250909
Epoch: 82/100
training r2: 0.9961195713550257, mean loss: 0.004688302433351055
Epoch: 83/100
training r2: 0.9964537987265749, mean loss: 0.004451523665920831
Epoch: 84/100
training r2: 0.9965007195822354, mean loss: 0.004249753299518488
Epoch: 85/100
training r2: 0.9952766414896639, mean loss: 0.005342122851288877
Epoch: 86/100
training r2: 0.9965416337287639, mean loss: 0.004103144659893587
Epoch: 87/100
training r2: 0.9954450219242742, mean loss: 0.0051249204989289865
Epoch: 88/100
training r2: 0.9967367979858002, mean loss: 0.0037255308488965966
Epoch: 89/100
training r2: 0.9961328934047874, mean loss: 0.0044315601444395725
Epoch: 90/100
training r2: 0.996738920442225, mean loss: 0.003717487008543685
Epoch: 91/100
training r2: 0.9960501740763243, mean loss: 0.004249425859597977
Epoch: 92/100
training r2: 0.9949915325389183, mean loss: 0.005676474203937687
Epoch: 93/100
training r2: 0.9944822243730989, mean loss: 0.006022317189490423
Epoch: 94/100
training r2: 0.9944785713113447, mean loss: 0.0061915624246466905
Epoch: 95/100
training r2: 0.9902922907694482, mean loss: 0.009277411510993261
Epoch: 96/100
training r2: 0.9932751325981961, mean loss: 0.008374881159397773
Epoch: 97/100
training r2: 0.9907874390780724, mean loss: 0.01135821822390426
Epoch: 98/100
training r2: 0.9889161612958226, mean loss: 0.013372064306167886
Epoch: 99/100
training r2: 0.982071001035948, mean loss: 0.0186732045840472
Epoch: 100/100
training r2: 0.984889725503154, mean loss: 0.01782958245894406

Code:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F ## relu, tanh
import torch.utils.data as DataLoader # helps create batches to train on
from scipy.stats import pearsonr
import numpy as np
import torch.utils.data as data_utils
torch.set_printoptions(precision=10)


#Hyperparameters
learning_rate=0.001
batch_size = 64
num_epochs=100

n = 1000
data_train = torch.randn(n, 1024, 43)
data_test = torch.randn(n, 1024, 43)
targets_train = torch.randn(n, 1)
device = 'cuda' if torch.cuda.is_available() else 'cpu' # fall back to CPU when no GPU is present

data_train1 = torch.Tensor(data_train)
targets_train1=torch.Tensor(targets_train)

dataset_train = data_utils.TensorDataset(data_train1, targets_train1)
train_loader = DataLoader.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)


class NN (nn.Module):
    def __init__(self):#input_size=43x1024
        super(NN,self).__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(1024, 512, kernel_size=4), #I tried different in and out here
            nn.ELU(),
            nn.BatchNorm1d(512),
            nn.Flatten(),
            nn.Linear(512*40, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        
    def forward(self, x):
        return self.layers(x)

torch.manual_seed(100)

#Initialize network
model=NN().to(device)

#Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate) #to check 


#Training the model
metric_name='r2'

for epoch in range(num_epochs):
    score=[]
    loss_all=[]
    print(f"Epoch: {epoch+1}/{num_epochs}")

    model.train()

    for batch_idx, (data, targets) in enumerate(train_loader):
        data=data.to(device=device)
        targets=targets.to(device=device)

        optimizer.zero_grad()
        
        #forward
        predictions=model(data)
        
        loss=criterion(predictions,targets).to(device)

        # loss.requires_grad = True

        #backward
        loss_all.append(loss.item())
        loss.backward()
    
        #gradient descent or adam step
        optimizer.step()
        
        #computing r squared
        output=predictions.detach().cpu().numpy()
        target=targets.detach().cpu().numpy()
        output=np.squeeze(output)
        target=np.squeeze(target)
    
        score.append(pearsonr(target, output)[0]**2)

    total_score = sum(score)/len(score)
    print(f'training {metric_name}: {total_score}, mean loss: {sum(loss_all)/len(loss_all)}')

Can you verify whether the model can learn from random training data on your setup?


oh my, I just found the line

torch.set_grad_enabled(False)

somewhere in the global settings, which is what actually caused the problem!
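A minimal sketch of why this switch makes training fail silently, using a tiny stand-in model and random data (both hypothetical). With gradients disabled globally, the forward pass records no graph, so the loss is a plain leaf tensor; setting requires_grad on it is accepted, backward() only touches the loss itself, and the optimizer has nothing to apply:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

torch.set_grad_enabled(False)  # the global switch found in the settings
try:
    before = model.weight.detach().clone()

    loss = criterion(model(x), y)
    graph_built = loss.grad_fn is not None   # no graph is recorded

    loss.requires_grad = True   # accepted, because the loss is now a leaf
    loss.backward()             # only fills loss.grad, not the parameters
    optimizer.step()            # skips parameters whose .grad is None

    grads_missing = all(p.grad is None for p in model.parameters())
    weights_unchanged = torch.equal(before, model.weight.detach())
finally:
    torch.set_grad_enabled(True)  # restore the default for later code

print(graph_built, grads_missing, weights_unchanged)
```

This matches the symptoms in the thread: no error, a loss that never moves, and identical scores on non-shuffled batches.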

And yes, loss.requires_grad = True needs to be removed. Good catch!
Thank you for your help! I appreciate it!
