MLP regression model always outputs the same value (approaching zero)

Hi, I am writing a simple MLP model, but its output is always the same no matter what the input is, and every element of the output vector is close to zero.

Here is my model:

import torch


class MLP(torch.nn.Module):
    def __init__(self, D_in, D_out):
        super(MLP, self).__init__()
        self.linear_1 = torch.nn.Linear(D_in, 1000)
        self.linear_2 = torch.nn.Linear(1000, 1500)
        self.linear_3 = torch.nn.Linear(1500, 1000)
        self.linear_4 = torch.nn.Linear(1000, 750)
        self.linear_5 = torch.nn.Linear(750, 500)
        self.linear_6 = torch.nn.Linear(500, 250)
        self.linear_7 = torch.nn.Linear(250, D_out)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # Six hidden layers, each followed by a sigmoid
        x = self.sigmoid(self.linear_1(x))
        x = self.sigmoid(self.linear_2(x))
        x = self.sigmoid(self.linear_3(x))
        x = self.sigmoid(self.linear_4(x))
        x = self.sigmoid(self.linear_5(x))
        x = self.sigmoid(self.linear_6(x))
        # Linear output layer (no activation) for regression
        y_pred = self.linear_7(x)
        return y_pred

I tried normalizing my data before feeding it into the model, and I also tried making the model simpler, but it is still not working… My input dimension is 450 and my output dimension is 120.
Can anyone give any suggestions?

I have tested my data with the MLP model in the scikit-learn package and it worked fine, so the problem is not my data but the model I built… Does anyone have suggestions?

Add a non-linearity (like ReLU) between your linear layers. A stack of linear layers with nothing in between is equivalent to a single linear layer.
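To see why stacking linear layers without an activation gains nothing, here is a small NumPy check. Random weights stand in for trained layers and the sizes are arbitrary. Composing y = W2 (W1 x + b1) + b2 gives exactly the single affine map W = W2 W1, b = W2 b1 + b2:

```python
import numpy as np

rng = np.random.default_rng(0)

# Two "linear layers" (y = W x + b) with no activation between them
W1, b1 = rng.normal(size=(8, 5)), rng.normal(size=8)
W2, b2 = rng.normal(size=(3, 8)), rng.normal(size=3)

x = rng.normal(size=5)
stacked = W2 @ (W1 @ x + b1) + b2

# The same mapping collapsed into one linear layer
W, b = W2 @ W1, W2 @ b1 + b2
single = W @ x + b

print(np.allclose(stacked, single))  # True
```

Putting any non-linear function between the two layers breaks this identity, which is what gives depth its extra expressive power.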

There are sigmoid layers between each linear layer in my model.

Can you share a few example data points to test this?

Hi, here is an example of the data:
Input (450 dimensions):
-0.735818388214886 0.149285303851154 -0.578067021181189 1.13708538016547 0.767335229025866 -1.00972711021682 1.44431640273288 -0.303122528928313 -1.72468227791404 0.793556029576177 0.784100630718448 0.586480416291804 -0.915689110003437 0.869155049096896 -0.327297531852296 -1.52501241901532 -1.03276643390661 -0.473968816820132 -0.295609810072135 -0.515859456805497 -1.13147975550565 1.02107095888883 -1.59113754353026 -0.00599659576347787 0.113807615216029 1.77224151085492 -1.18130844786379 -0.546862215529291 1.33021374434445 0.544488914484188 0.663008310935703 -0.858810844840316 -0.687197743326812 -0.866979526762939 -1.06245366415692 -0.469133495436108 0.315262120088441 1.51910537341530 0.535432344697486 -0.802471347416508 0.743077742553866 0.227906602985508 0.865005599397846 -0.208548061188331 0.0610479314247161 -0.296093226018993 0.836297939209928 -0.480363187213705 0.209402204752636 -0.144933987024996 -0.206533549880378 0.946925786032395 -0.537818444704563 -0.950879565089832 -0.558620082910992 -1.76818792251994 -0.325854266846343 0.596675452791685 -1.53441716013984 -0.0939427711905128 0.240581130309243 1.37753811785750 -0.448188963430959 1.01632917448316 -1.98344040443072 1.43189848114174 0.00380709676953635 0.731040941080028 -0.418491596787389 -1.45116140048975 -0.0757198878866544 0.906440506235143 -0.195489176181480 0.474885044433439 -1.75296645099251 0.0537898740590009 0.166840316185414 -0.505101220206424 -1.10305805337000 2.19016304827521 -1.64267059816572 -0.846608492414494 -0.358042930737871 0.927555101990779 -0.756696277713888 0.552364304352916 -0.377732413206058 1.46947147354448 -0.772770419104887 -0.741050050782310 1.31761816206993 1.51466053722297 0.939025985378438 -0.245934698986654 1.55397168655323 -1.31062933778198 -0.218447634575306 2.11682833593488 -0.690610835809335 1.34565879808533 0.499611055253914 1.08932242212239 -0.165717938893029 0.590346641605538 1.37617634914702 0.329118285442481 -2.33539214142729 0.848313901922739 -0.0194600743277051 
0.645632341231397 1.77066412774636 -0.547402479954773 1.24247230872903 1.13078942868089 2.36441975954544 0.343426659187261 -1.35464768225976 0.486124789352450 0.338884799359261 1.01189291221377 -0.158799837350974 1.12447248143180 -0.918438770221236 0.0872829492529977 -0.452961972932257 2.06814277340707 0.410710144712870 0.0107664210118090 0.483238106571439 -0.861855465344864 -0.429852428427828 0.0846469901005720 0.912335340365856 1.32693787964457 -0.0425986801530204 -0.425727244814729 0.0881495167935712 -1.18746456014654 0.510306793511959 0.422295192312096 -0.418692851559872 0.130512265453763 0.288796882468738 1.07626440985192 -1.67555519143935 -1.51063819598542 -1.12023297875842 -0.578940433520568 0.416479009939568 1.41211309181935 0.272852582995746 0.787582140256920 -0.164531828814356 -0.892619457089434 2.35387313335990 2.15444265679402 0.853564388998705 0.773964231932885 0.392984982417099 0.413719284040007 0.408039592214066 0.179285634255306 1.18993136068326 0.346673640039687 0.849820555281563 0.663293846728990 2.37788541588976 -0.410419910437476 1.85072878124200 0.0945196445828797 0.190227044494937 -0.00862260862497045 -2.33048007706741 0.350072318185885 1.77538976014696 0.0105123358423785 -0.797439870313351 0.789658533392219 -0.0598661168735266 -0.512509292829113 0.803342888759277 0.515296074529683 -0.0199060957481239 -0.203846444532534 -0.363403865186734 -0.284998855547540 -0.870345165251578 0.0255259175842855 1.07548032372644 0.0375875009112331 0.0408941606327486 -0.118744108599266 1.87274416782378 -0.475182333652753 0.421003048629767 1.52685393347522 0.303303793119889 -1.71329299633793 -1.76254559768078 0.760726416738919 -0.0369942976896756 -1.53714846782517 0.532768205421902 0.613476076198292 -1.24587083978263 -0.0908267382857597 -1.04976688505827 -0.544466696790154 0.616182440954475 0.342816759355589 0.144503438142935 -1.48705314240430 -0.447520691642277 -1.53098062179696 -0.00626829017518429 -1.49407926767421 -0.957375697157989 0.755564452147370 
-0.745437739000280 0.143221812338554 -0.961105606309910 -1.19637145356054 -0.801846004846102 -1.98341402239006 -0.141406185781072 0.284864088404030 -2.22626832605181 1.52653563599279 2.13960852388333 1.53970999012204 1.15620581147065 -0.595187848735976 0.678270451170598 -1.52005318705716 -0.157528454513269 -1.23280723924239 0.752606156694844 -1.04186084853909 0.608394332176887 0.689406908096181 -0.256367529050619 -0.280540634731785 -0.307800029661493 -0.527898501075078 1.08610152723436 -0.481592519129528 -0.212290828324039 0.126002734363314 -0.747989928955777 0.489271117537880 -0.186739355730084 0.533719515040723 -0.544797495851650 1.50339005591187 0.541926827188491 -0.243197671411433 -1.79788547055804 -0.534596968093643 -0.315525814296939 2.01761951163463 -0.349775082407994 -1.32697498989970 0.999048058049433 0.0692273613508713 -0.00892051731983100 -0.136649250302277 -0.515509695060058 0.159847278104185 1.31922530844349 0.679973155251442 0.297837499659531 -0.863219090612890 -0.902593836350390 0.00547590038320574 -0.0239774096454648 -0.771092586061540 0.0347153321517812 -1.46902769034476 -0.989314484344415 -1.38564595107017 -0.168557615557302 -2.14207620366500 1.51409818597145 0.589519340162154 0.489306004608671 -0.0990063450521083 0.569405655417508 0.348377463731985 0.351944229498122 -0.929034725173021 1.01489841483581 -1.20219322184725 0.774219873135912 1.57064016547379 1.29844172635315 0.0123181847036591 1.60857482218866 0.306305598706328 -0.163473773068933 -0.863119136540574 0.577781460723887 -0.583891061960679 0.579100742187932 -0.261218616340131 1.35914586993769 0.651286722049995 -0.445858544568678 1.98711707884207 0.855070779457029 0.324888254251314 -2.18994697806142 -1.46217532162863 0.692398681057095 0.890738573207726 0.953608201959669 -2.17322386899052 2.09195117427766 -1.02683680479155 -1.59749397470924 -1.61398938735634 1.14708631109452 -1.48190439716087 -0.958084110040263 1.73699980384352 0.171497129806649 -0.375824884166825 -0.208896664000074 
-0.232916096907624 -0.999989068952020 1.43196331885695 1.05333800160815 -1.87769216621010 -0.373463909673348 -0.536355318780433 -0.0528307745492846 -0.140475950318341 0.355150500133399 1.72268869924568 0.916933214616019 0.663657095685507 0.997434965650828 0.808879108223415 -0.101854813383192 1.06833962008493 -1.30872729736308 -0.382603177176991 -2.50598994166440 0.127192043657772 -0.317857101989867 0.0109071459772834 0.472353545947209 -2.90412235034735 0.293469599086924 0.949976471229498 1.82486915988484 -0.724280810080997 1.68405009714927 0.639574714575609 -0.867532742695344 -1.80164016593070 -0.725316475286614 -0.802243410765602 -1.58894967338598 -0.593816526590895 1.54429631044769 0.0759264897006957 -0.848576163882861 1.46855445829179 -1.12023631478264 -0.584155321314308 1.73053238157822 -0.321197545709661 0.0786746576675956 -1.43887198310287 0.102445769475108 1.14323107544532 0.305349972268648 -0.808809830969838 0.784810297125262 -0.815665596211662 -0.872831400760270 1.37296766288476 0.418719711616480 -0.948158182619194 0.878295772904781 1.31898034581479 -0.817559745492998 0.831082951440480 0.949303010793698 2.01020656366090 -0.432122743923586 0.925987043019444 0.546283688870168 -0.383790650143911 3.06659138404548 0.0999160740625315 -1.10834298087316 0.368918512186226 0.508676688719668 0.353856030151058 -1.08901594591339 0.482262488823247 -1.37520937859426 -0.627315838862955 0.781117715520164 -0.0684018950749005 -0.424245568746243 0.688157912579803 0.0794974830211469 -0.593836626710330 -0.138218083932841 -0.625889046108883 -0.464703705404994 -0.617654342027112 0.461580452263993 0.0478175605105834 -0.215060595309493 -1.05892181106604 -0.451188490131863 0.0599069991116341 0.455182505311400 -1.53276713244221 0.512244554956416 -1.27103120201129 -0.950906642491205 0.182339506439061 1.08338628771804 -0.536682474414237 -1.06959635212258 0.459856625436814 -0.506621610555142 0.202810486436882 0.505627177830883 -1.35180754744127 1.28325363856324 0.745357456683776 
-0.546185670707208 0.597550026378408 1.74778911588788 0.192450400164104 -0.240390964552302 0.112523767557991 -2.81556039054814 -0.218262017995826 -1.03440307996590 -0.350841063479356 -0.520668669845006 0.587456453857042 -0.664428434754110 -0.888170611420252

Output (120 dimensions):
0.532300212681870 0.179998690554949 -0.0257466843027034 -0.205882453641676 0.591182642419381 -0.0886731841952033 0.263522930131506 -0.232009699062619 -0.568828657246618 0.0574685230173592 -1.60284642144261e-07 0.0662690547702825 -0.0564374703109907 -0.407917260673937 -0.343127227975678 0.180063677196068 -0.492261357637083 -0.336502362526927 0.0226164378018748 0.391798473050306 -0.511955179087885 -0.400479597369964 0.0815137946320438 -0.429874559130754 -0.165088072840650 -0.482222177400369 0.346141730818296 -0.470105832986738 0.0148054113129904 -0.248413666577270 7.85676120099218e-06 -0.109881131720543 0.487259200095905 -0.228457716311294 0.239037570703603 0.206581423495658 0.0734277499113765 -0.292353134135843 0.236769905240832 0.418475528074645 0.343941269578648 -0.242578842189166 -0.403348923623755 0.0476801596732162 0.514682257205765 0.291071959078389 -0.521966344699764 -0.0305610147893470 -0.0753008830516899 -0.275737840822758 1.73026709799811e-05 0.451837795234245 -0.377204036558573 -0.102295416442940 0.189983419757157 0.173017195991341 0.209241386814994 0.121568362397577 0.482217214315650 -0.489043431180480 0.685569308661741 0.0754196779260324 0.264387863585074 -0.460677488823237 -0.0655079802930742 0.0564285529922251 0.0224475949085858 -0.806614168786861 -0.246844101731706 -0.438966486336961 1.68718295751309e-05 0.262910868990014 -0.00652996986393479 0.320597507699597 0.256793848901400 -0.0761349976903453 -0.0714002589137728 -0.0600250235381107 0.250188193002387 -0.122352971769150 0.268844153278181 -0.194310232938437 -0.525340063393711 -0.271752750190912 0.0797757867755889 -0.222341163148859 0.0545291377135724 0.103619676878438 -0.233122479602490 -0.211476963323040 1.24631248855083e-05 -0.130492588452597 -0.520165791152049 -0.483149516929881 0.112854850291193 0.106650729585523 -0.399795351834588 0.368224983765443 -0.559589475177480 -0.466738283723279 0.684576301742124 0.00388587072033813 -0.532243187752000 0.0542254933565711 -0.186914327254024 
-0.295018353999371 0.00468540340194501 0.238713708177180 0.131293241163625 -0.321632809816311 -1.10141122226793e-06 0.278443194578431 0.312817519725799 0.0718265446824297 0.172149761992720 0.406982900273825 0.259607818592667 0.539530859757522 0.454254882552196 -0.0692824383104469

And the output I got when testing the model looks like this (it stays the same no matter what the input is):
0.0032217987 -0.0032565966 -0.0021627229 -0.22697754 -0.0024142414 0.0026991144 -0.0015234398 -0.0028859563 0.0021582916 -0.21696138 -6.2100589e-06 0.0091556832 0.0012200102 -0.0035693087 0.0046808752 0.0068705790 -0.0010050870 0.0015123896 0.0037639998 0.00036380813 0.00043297932 -0.00035620201 -0.0032979734 -0.23171432 0.0045496225 -0.0012066588 -0.0021106657 -0.0041744113 -0.012195184 -0.21638247 2.4028122e-06 0.0059865061 -0.0042346753 -0.00071774051 -0.0026111864 -0.0018474497 -0.00078333169 -0.0054079369 0.0077128373 0.0016663412 0.0032119900 -0.0062416419 -0.0021390468 -0.22433044 -0.0027336776 0.0028707050 0.0022469088 -0.0034709517 -0.0028906241 -0.21881813 5.7823956e-05 0.0032278486 -0.0022484893 -0.010667327 -0.00026572868 -0.0018052831 0.0012878384 0.0035010017 -0.0034696758 0.0034123622 -0.0081416154 0.00037829392 0.0034989491 -0.22207867 0.010135166 0.0018704161 0.0050993972 0.0022409894 -0.0079786349 -0.22068146 -1.6111881e-06 -0.0033761226 0.00018132944 0.0041641518 0.0018924810 -0.0017091706 -0.0061378721 0.00092485966 -0.0015861169 0.00036317296 0.0098521076 0.0076599456 -0.0084095690 -0.22120586 -0.010473695 -0.0077724494 -0.0068498887 -0.0096608894 -0.0033896575 -0.22154866 0.00010513514 0.0015144497 0.0026684590 0.0033703633 -3.2403506e-05 0.0055518188 0.0079380125 -0.0090062730 0.0057077296 -0.0058418475 0.0070619956 0.0019653030 0.0018671704 -0.23226789 0.0086756833 -0.00078263949 0.0070048124 -0.0047388561 -0.0073986426 -0.21967286 -7.0333481e-05 0.0058467574 -0.0092798918 -0.0033651199 0.0044662654 -0.0027292185 -0.0056617074 0.0037126155 0.0060428437 0.0012843013

Thanks !

That’s a “brute force” network, so at this depth it will learn slowly even when everything is correct. Here, though, your problem is likely the saturating sigmoid; try non-saturating activations instead: [leaky_]relu or selu (as in self-normalizing networks).
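For reference, the sigmoid's derivative σ'(z) = σ(z)(1 − σ(z)) is at most 0.25, so each sigmoid layer can shrink the backpropagated gradient by a factor of 4 or more, and six of them in a row compound that. Below is a minimal sketch of the original architecture rebuilt with nn.Sequential and a swappable activation, so LeakyReLU or SELU can be tried without other changes. The class layout is an assumption for illustration; a full self-normalizing setup would also use LeCun-normal initialization and AlphaDropout, which this sketch leaves out.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Same layer widths as the original model, with a swappable activation."""

    def __init__(self, D_in, D_out, activation=nn.SELU):
        super().__init__()
        widths = [D_in, 1000, 1500, 1000, 750, 500, 250]
        layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(w_in, w_out), activation()]
        # No activation on the output layer: this is a regression head
        layers.append(nn.Linear(widths[-1], D_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


model = MLP(450, 120, activation=nn.LeakyReLU)
y = model(torch.randn(4, 450))
print(y.shape)  # torch.Size([4, 120])
```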

Hi, thanks for your reply. Actually, I tried ReLU as the activation and also reduced the depth of the network, but the same thing still happened. I am wondering whether this weird output is caused by the high dimension of my output. Any comments?

What do your loss and accuracy curves look like?

Some potential issues:

  • optimizer
    • make sure to pass all of your model’s parameters to the optimizer: optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum=0.9) as an example.
    • make sure to clear the gradients at the start of each batch: optimizer.zero_grad()
  • hyperparameters
    • make sure your hyperparameters have reasonable values (learning rate, batch size, weight decay, etc.), and when in doubt use the defaults. If you’re using L1/L2 regularization, make sure your weight decay value is not set too high; otherwise your model may end up optimizing for driving its parameters toward 0 rather than for good predictions.
  • loss function
    • make sure you’re using the correct loss function for your use case (MSELoss I imagine).
  • monitor the gradients to check that they are not vanishing or exploding.
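A sketch of one training step that follows this checklist: gradients are cleared before the backward pass, and per-parameter gradient norms are recorded so that vanishing or exploding gradients become visible. The helper name train_step, the small two-layer model and the random tensors are placeholders, not code from the thread.

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, loss_fn, x, y):
    """One optimization step; returns the loss and per-parameter gradient norms."""
    model.train()
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss = loss_fn(model(x), y)
    loss.backward()
    # Record gradient norms before the update to spot vanishing/exploding gradients
    grad_norms = {
        name: p.grad.norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    optimizer.step()
    return loss.item(), grad_norms


# Placeholder model and data with the thread's input/output sizes
model = nn.Sequential(nn.Linear(450, 256), nn.ReLU(), nn.Linear(256, 120))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.MSELoss()

x, y = torch.randn(100, 450), torch.randn(100, 120)
loss, grad_norms = train_step(model, optimizer, loss_fn, x, y)
for name, norm in grad_norms.items():
    print(f"{name}: {norm:.3e}")
```

Norms that shrink by orders of magnitude from the last layer toward the first point to vanishing gradients; very large or growing norms point to exploding ones.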

Hi stroncea, thanks for the detailed explanation! Yes, I am using MSE as my loss function, and I checked my code many times to make sure the optimizer is set up correctly and the gradients are cleared at the beginning of each batch.
I have 3500 training points in total and set my batch size to 100. I also have a validation set of 750 data points.
Below is the training history:
[Figure: training history]
You can see that the loss quickly dropped to around 0.1, and after that the model’s weights seem to stop changing (so the output is the same for all inputs).
Another thing I noticed: if I create a matrix A of zeros and calculate the MSE between A and y_test_true, the result is also about 0.1. This explains why all my outputs are close to zero (the model is being lazy and just sets its weights so that all outputs are zero). But I have no idea how to improve this.
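One way to make the all-zeros observation systematic is to compare the model's validation MSE against trivial predictors that ignore the input: predicting all zeros, and predicting the per-output mean of the training targets. A sketch in NumPy follows; baseline_mses is a hypothetical helper, and the random targets only mimic the thread's set sizes (3500 training and 750 held-out samples, 120 outputs).

```python
import numpy as np


def baseline_mses(y_train, y_test):
    """MSE of two trivial predictors that ignore the input entirely."""
    zero_pred = np.zeros_like(y_test)
    # Predict the per-output training mean for every test sample
    mean_pred = np.broadcast_to(y_train.mean(axis=0), y_test.shape)
    return {
        "zeros": np.mean((y_test - zero_pred) ** 2),
        "train_mean": np.mean((y_test - mean_pred) ** 2),
    }


# Synthetic targets standing in for the real data
rng = np.random.default_rng(0)
y_train = rng.normal(0.5, 0.3, size=(3500, 120))
y_test = rng.normal(0.5, 0.3, size=(750, 120))
baselines = baseline_mses(y_train, y_test)
print(baselines)
```

If the trained network's MSE is roughly equal to one of these baselines, it has most likely learned to ignore its input.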

Thanks

@Yuchen_Mu can you paste your training loop and the parameters you use for it?

Hi, it is like this:

import numpy as np
import torch

loss_function = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(total_epoch_number):
    model.train()
    loss_array_for_current_epoch = []
    for batch_num in range(number_of_training_data // N):
        # Slice with the batch size N instead of a hard-coded 100
        start, end = batch_num * N, (batch_num + 1) * N
        current_batch_training_input = training_set_input[start:end, :].to(device)
        current_batch_training_output = training_set_output[start:end, :].to(device)

        # Zero gradients, run the forward and backward passes, and update the weights.
        optimizer.zero_grad()
        y_pred = model(current_batch_training_input)
        loss = loss_function(y_pred, current_batch_training_output)
        loss.backward()
        optimizer.step()

        loss_array_for_current_epoch.append(loss.item())

    # Use the average loss of all batches as the overall training loss for this epoch.
    overall_training_loss_for_current_epoch = np.average(loss_array_for_current_epoch)

    # Validation process at current epoch (inputs moved to the same device as the model)
    model.eval()
    with torch.no_grad():
        y_pred_val = model(validation_set_input.to(device))
        val_loss = loss_function(y_pred_val, validation_set_output.to(device)).item()
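For comparison, here is a sketch of the same loop built on torch.utils.data.DataLoader, which reshuffles the training set every epoch (the loop above always visits the batches in the same order) and switches between model.train() and model.eval(). The helper names run_epoch and evaluate, the small model, the Adam learning rate and the random tensors are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, loss_function, optimizer, device):
    """Train for one epoch and return the mean batch loss."""
    model.train()
    batch_losses = []
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        loss = loss_function(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)


@torch.no_grad()
def evaluate(model, x, y, loss_function, device):
    """Validation loss with dropout/batch-norm in eval mode and no autograd."""
    model.eval()
    return loss_function(model(x.to(device)), y.to(device)).item()


# Placeholder data with the shapes from the thread (3500 train / 750 val)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_x, train_y = torch.randn(3500, 450), torch.randn(3500, 120)
val_x, val_y = torch.randn(750, 450), torch.randn(750, 120)

# shuffle=True draws a new batch order every epoch
loader = DataLoader(TensorDataset(train_x, train_y), batch_size=100, shuffle=True)
model = torch.nn.Sequential(
    torch.nn.Linear(450, 512), torch.nn.ReLU(), torch.nn.Linear(512, 120)
).to(device)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    train_loss = run_epoch(model, loader, loss_function, optimizer, device)
    val_loss = evaluate(model, val_x, val_y, loss_function, device)
    print(f"epoch {epoch}: train {train_loss:.4f}  val {val_loss:.4f}")
```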

What kind of MSE did scikit-learn give you compared to this?

Also, have you tried adding some regularization, either Dropout or weight decay?
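Both kinds of regularization are one-liners in PyTorch: nn.Dropout layers in the model, and the weight_decay argument of the optimizer (for plain SGD this acts as an L2 penalty on the weights). A minimal sketch; the trimmed layer widths, p=0.2 and weight_decay=1e-4 are arbitrary starting values, not tuned for this data.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(450, 1000), nn.ReLU(), nn.Dropout(p=0.2),
    nn.Linear(1000, 250), nn.ReLU(), nn.Dropout(p=0.2),
    nn.Linear(250, 120),  # no dropout on the regression output
)
# weight_decay adds an L2 penalty through the optimizer update
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)

x = torch.randn(8, 450)
model.train()
# Train mode: dropout is active, so two passes over the same input usually differ
print(torch.equal(model(x), model(x)))
model.eval()
# Eval mode: dropout is disabled, so the passes match
print(torch.equal(model(x), model(x)))
```

Because dropout only acts in train mode, calling model.eval() before validation matters; otherwise the validation loss is measured with random units dropped.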

I did try this with MLPRegressor in the scikit-learn package, and the result is similar… (the test outputs remain the same no matter what the input is).
Yes, I tried adding a dropout layer after each activation layer. The results improved a little, but not much, and, as expected, the MSE loss is a little higher with dropout than without it.

Have you tried using a single hidden layer (perhaps a wider one)? If that works better, the problem is excessive depth (gradients are either tiny or pulled in different directions by different mini-batches). If not, either the optimizer needs tweaking or there is a bug somewhere (the loop looks OK, though).
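To compare a single wide hidden layer against a deep stack without rewriting the class each time, the depth can be made a parameter. A sketch; build_mlp and the widths chosen below are illustrative assumptions, not values from the thread.

```python
import torch
import torch.nn as nn


def build_mlp(d_in, d_out, hidden_width, n_hidden, activation=nn.ReLU):
    """Fully connected regressor with n_hidden equally wide hidden layers."""
    layers, width_in = [], d_in
    for _ in range(n_hidden):
        layers += [nn.Linear(width_in, hidden_width), activation()]
        width_in = hidden_width
    layers.append(nn.Linear(width_in, d_out))  # linear regression head
    return nn.Sequential(*layers)


shallow = build_mlp(450, 120, hidden_width=4000, n_hidden=1)
deep = build_mlp(450, 120, hidden_width=1000, n_hidden=6)
for name, m in [("shallow", shallow), ("deep", deep)]:
    n_params = sum(p.numel() for p in m.parameters())
    print(name, n_params)
```

Training both with the same loop and optimizer settings isolates the effect of depth from everything else.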

Hi, I did try using only one hidden layer. The result improved a little, but not much, and in this case the model is obviously underfitting.

You should be able to overfit a single-hidden-layer MLP by increasing its width and/or the learning rate. Then move back toward your deep model step by step, and you’ll see where it stops working.

Hi, do you mean increasing the number of neurons in the hidden layer? The model I originally used had 1000 neurons in the first hidden layer; if I build an MLP with only one hidden layer, should I make this number even larger?
Thanks.

Yeah, use a lot of neurons (maybe 10k to 100k) to verify that your training loss can approach zero (overfit). That would indicate there are no coding errors and that the optimizer parameters are adequate. As for optimizers, try Adadelta or Rprop to avoid learning-rate tuning.
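A sketch of the suggested overfitting check: train a very wide single-hidden-layer MLP on a small fixed subset, full batch, and confirm that the training loss keeps falling. The random 64-sample subset, the width of 10,000 and the 100 steps are assumptions chosen to keep the run short. Rprop is designed for full-batch updates, which is why the whole subset is used at every step; its lr is only the initial per-parameter step size, and lowering it to 1e-4 here is an assumption for this very wide layer. torch.optim.Adadelta(model.parameters()) can be swapped in instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small fixed subset: placeholder random data with the thread's dimensions
x = torch.randn(64, 450)
y = torch.randn(64, 120)

model = nn.Sequential(nn.Linear(450, 10_000), nn.ReLU(), nn.Linear(10_000, 120))
loss_fn = nn.MSELoss()
# Rprop adapts each parameter's step size from the sign of its gradient;
# lr=1e-4 (assumed) is the initial step size, not a fixed learning rate
optimizer = torch.optim.Rprop(model.parameters(), lr=1e-4)

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)  # full batch every step
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.4f}")
```

If the loss on such a small subset does not head toward zero, the problem is more likely in the code or optimizer setup than in the architecture.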

Hi, let me try your advice and see whether I can get something.
Thank you!