Different loss for cpu and gpu

I am training a bilinear similarity model with a multi-label max-margin loss function. The model trains perfectly fine on the CPU and returns the expected results (close to the baseline). However, it returns completely incorrect results when I train it on the GPU. How do I debug this problem?

This is the paper that I am trying to replicate - https://www.cs.northwestern.edu/~ddowney/publications/zheng_aaai_2018.pdf

Any help would be appreciated!

Hi,

This is quite unexpected. Are you sure you perform the same initialization/steps on both CPU and GPU?
Can you pinpoint more precisely where the two runs diverge?

Yes, I am performing the same initialization for both CPU and GPU. Here’s the architecture of the model -

BiLinear(
  (input_encoder): BiLSTM_Attention(
    (dropout): Dropout(p=0.5, inplace=False)
    (feature_embedding): Embedding(600000, 50, padding_idx=0)
    (embeddings): Embedding(
      (embeddings): Embedding(37119, 300)
    )
    (left_bilstm): BiLSTM(
      (lstm): LSTM(300, 100, bidirectional=True)
    )
    (right_bilstm): BiLSTM(
      (lstm): LSTM(300, 100, bidirectional=True)
    )
    (attention): AttentionMLP(
      (linear_1): Linear(in_features=200, out_features=100, bias=True)
      (linear_2): Linear(in_features=100, out_features=1, bias=True)
      (tanh): Tanh()
      (softmax): Softmax(dim=None)
    )
  )
  (label_encoder): LabelEncoder(
    (label_encoder): AverageLabelEncoder(
      (embeddings): Embedding(99, 300)
      (label_layer): Linear(in_features=300, out_features=20, bias=False)
    )
  )
  (input_layer): Linear(in_features=550, out_features=20, bias=False)
)

and the code that I use to train the model is:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable   # only needed for the pre-initialized `loss` below
    from tqdm import tqdm

    model = BiLinear(config).to(config.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    loss_fn = nn.MultiLabelMarginLoss()
    all_labels = list(config.label_to_idx.values())

    for epoch in range(config.epochs):
        print("Epoch No.", epoch+1)

        total_loss = torch.Tensor([0.]).to(config.device)

        model.train()
        idx = 0
        for x in tqdm(training_data):
            optimizer.zero_grad()
            loss = Variable(torch.Tensor([0.]))
            score_main = model(x['mention_tokens'], x['left_tokens'], x['right_tokens'], x)

            loss = loss_fn(score_main, x['labels'])
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            idx += 1

        print("Total Loss = ", total_loss)

After 1 epoch, the train loss for the CPU run is somewhere around 200, while the train loss for the GPU run is around 50492384. I’m using the train loss as a sanity check.

Your code looks OK from a high-level read.
Can you pinpoint when this difference appears? Is it there before training? Does one loss go down while the other doesn’t? Does one diverge and the other not?

PS:
You can initialize your loss accumulator with a plain Python number: total_loss = 0. There’s no need to wrap these values in tensors.
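
That is, in your loop above:

    total_loss = 0.0                      # plain Python float instead of a tensor
    for x in tqdm(training_data):
        # ... forward pass, loss, backward, optimizer step exactly as before ...
        total_loss += loss.item()         # .item() already returns a Python number
    print("Total Loss = ", total_loss)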

The difference appears during training. Here’s a look at the loss for each batch during training.

For the GPU -

batch no 1 loss =  1.48519766330719
batch no 2 loss =  1.349387288093567
batch no 3 loss =  1.2793351411819458
batch no 4 loss =  1.1422659158706665
batch no 5 loss =  1.065197229385376
batch no 6 loss =  0.8988766074180603
batch no 7 loss =  0.795573353767395
batch no 8 loss =  0.7166022062301636
batch no 9 loss =  0.5900906920433044
batch no 10 loss =  0.5402706265449524
batch no 11 loss =  0.5099382400512695
batch no 12 loss =  0.5270498991012573
batch no 13 loss =  0.4964383840560913
batch no 14 loss =  0.4695017337799072
batch no 15 loss =  0.4864051938056946
batch no 16 loss =  0.5001940131187439
batch no 17 loss =  0.5504444241523743
batch no 18 loss =  0.4995945990085602
batch no 19 loss =  0.6236541867256165
batch no 20 loss =  0.5803393125534058
batch no 21 loss =  0.6943713426589966
batch no 22 loss =  0.6918058395385742
batch no 23 loss =  0.7231294512748718
batch no 24 loss =  0.8104425072669983
batch no 25 loss =  0.8485711216926575
batch no 26 loss =  0.9337277412414551
batch no 27 loss =  1.040302038192749
batch no 28 loss =  1.0833561420440674
batch no 29 loss =  1.2168757915496826
batch no 30 loss =  1.3513448238372803
batch no 31 loss =  1.3978110551834106
batch no 32 loss =  1.5755630731582642
batch no 33 loss =  1.643385887145996
batch no 34 loss =  1.8431206941604614
batch no 35 loss =  2.0576276779174805
batch no 36 loss =  2.138352155685425
batch no 37 loss =  2.420424222946167
batch no 38 loss =  2.6673426628112793
batch no 39 loss =  2.9037210941314697
batch no 40 loss =  3.0415878295898438
batch no 41 loss =  3.3367340564727783
batch no 42 loss =  3.61075758934021
batch no 43 loss =  3.875593900680542
batch no 44 loss =  4.022735595703125
batch no 45 loss =  4.232641220092773
batch no 46 loss =  4.469114303588867
batch no 47 loss =  4.6978302001953125
batch no 48 loss =  4.908476829528809
batch no 49 loss =  5.408933639526367
batch no 50 loss =  5.91743278503418
batch no 51 loss =  6.056540489196777
batch no 52 loss =  6.3848652839660645
batch no 53 loss =  6.919095993041992
batch no 54 loss =  7.30460262298584
batch no 55 loss =  7.652694225311279
batch no 56 loss =  8.113967895507812
batch no 57 loss =  8.591318130493164
batch no 58 loss =  8.899323463439941
batch no 59 loss =  9.512104034423828
batch no 60 loss =  9.753629684448242
batch no 61 loss =  10.346807479858398
batch no 62 loss =  11.033885955810547
batch no 63 loss =  11.304972648620605
batch no 64 loss =  11.373547554016113
batch no 65 loss =  12.382369041442871
batch no 66 loss =  12.86313533782959
batch no 67 loss =  13.56275749206543
batch no 68 loss =  13.181768417358398
batch no 69 loss =  14.387351989746094
batch no 70 loss =  14.72762680053711
batch no 71 loss =  14.963506698608398
batch no 72 loss =  15.535669326782227
batch no 73 loss =  17.41385841369629
batch no 74 loss =  17.897266387939453
batch no 75 loss =  18.659147262573242
batch no 76 loss =  18.77376365661621
batch no 77 loss =  19.839771270751953
batch no 78 loss =  20.197492599487305
batch no 79 loss =  20.273174285888672

and this keeps increasing for the GPU.

For the CPU:

batch no 1 loss =  1.5018784999847412
batch no 2 loss =  1.236070990562439
batch no 3 loss =  1.094645380973816
batch no 4 loss =  0.9613571166992188
batch no 5 loss =  0.8980026841163635
batch no 6 loss =  0.7537975311279297
batch no 7 loss =  0.6927058696746826
batch no 8 loss =  0.6422103047370911
batch no 9 loss =  0.5273403525352478
batch no 10 loss =  0.4687236547470093
batch no 11 loss =  0.4736068546772003
batch no 12 loss =  0.4473775029182434
batch no 13 loss =  0.44782334566116333
batch no 14 loss =  0.3897721469402313
batch no 15 loss =  0.37255093455314636
batch no 16 loss =  0.40343356132507324
batch no 17 loss =  0.41588154435157776
batch no 18 loss =  0.3447286784648895
batch no 19 loss =  0.3894757628440857
batch no 20 loss =  0.3367255628108978
batch no 21 loss =  0.36912691593170166
batch no 22 loss =  0.3307099938392639
batch no 23 loss =  0.31340059638023376
batch no 24 loss =  0.3092980682849884
batch no 25 loss =  0.2926962673664093
batch no 26 loss =  0.30344510078430176
batch no 27 loss =  0.29070281982421875
batch no 28 loss =  0.2682488262653351
batch no 29 loss =  0.28866472840309143
batch no 30 loss =  0.2943607568740845
batch no 31 loss =  0.2776503562927246
batch no 32 loss =  0.2795364260673523
batch no 33 loss =  0.2474450021982193
batch no 34 loss =  0.24592795968055725
batch no 35 loss =  0.25801122188568115
batch no 36 loss =  0.2531129717826843
batch no 37 loss =  0.22873735427856445
batch no 38 loss =  0.2336498498916626
batch no 39 loss =  0.24088449776172638
batch no 40 loss =  0.21383072435855865
batch no 41 loss =  0.2321128398180008
batch no 42 loss =  0.21124516427516937
batch no 43 loss =  0.2158881574869156
batch no 44 loss =  0.22324740886688232
batch no 45 loss =  0.20954640209674835
batch no 46 loss =  0.18683622777462006
batch no 47 loss =  0.20678168535232544
batch no 48 loss =  0.17524667084217072
batch no 49 loss =  0.18902267515659332
batch no 50 loss =  0.19290755689144135
batch no 51 loss =  0.17181488871574402
batch no 52 loss =  0.1689278483390808
batch no 53 loss =  0.17976917326450348
batch no 54 loss =  0.1698329597711563
batch no 55 loss =  0.16441339254379272
batch no 56 loss =  0.16918373107910156
batch no 57 loss =  0.16696254909038544
batch no 58 loss =  0.1432090401649475
batch no 59 loss =  0.16007395088672638
batch no 60 loss =  0.14895042777061462
batch no 61 loss =  0.157599538564682
batch no 62 loss =  0.17643079161643982
batch no 63 loss =  0.14895623922348022
batch no 64 loss =  0.136954203248024
batch no 65 loss =  0.15611445903778076
batch no 66 loss =  0.15810908377170563
batch no 67 loss =  0.14348700642585754
batch no 68 loss =  0.12974749505519867
batch no 69 loss =  0.13568048179149628
batch no 70 loss =  0.14065878093242645
batch no 71 loss =  0.12928995490074158
batch no 72 loss =  0.14016589522361755
batch no 73 loss =  0.15469464659690857
batch no 74 loss =  0.14998523890972137
batch no 75 loss =  0.13565079867839813
batch no 76 loss =  0.1321781426668167
batch no 77 loss =  0.1252971589565277
batch no 78 loss =  0.1260882318019867
batch no 79 loss =  0.12096744030714035

Does slightly reducing the learning rate fix the issue?

Yes, I just tried reducing the learning rate.
It takes a few more batches to reach the lowest batch loss, and then the loss starts increasing again.

But the rate at which the total loss increases is smaller.

Here’s how the output looks -

batch no 1 loss =  1.48519766330719
batch no 2 loss =  1.4672067165374756
batch no 3 loss =  1.4816129207611084
batch no 4 loss =  1.434296727180481
batch no 5 loss =  1.440385341644287
batch no 6 loss =  1.4176286458969116
batch no 7 loss =  1.40739107131958
batch no 8 loss =  1.3973239660263062
batch no 9 loss =  1.3658576011657715
batch no 10 loss =  1.3650410175323486
batch no 11 loss =  1.337887167930603
batch no 12 loss =  1.3581486940383911
batch no 13 loss =  1.3573602437973022
batch no 14 loss =  1.31741201877594
batch no 15 loss =  1.2534146308898926
batch no 16 loss =  1.3125391006469727
batch no 17 loss =  1.294553279876709
batch no 18 loss =  1.1889533996582031
batch no 19 loss =  1.2660199403762817
batch no 20 loss =  1.2492356300354004
batch no 21 loss =  1.2197800874710083
batch no 22 loss =  1.1822484731674194
batch no 23 loss =  1.1688812971115112
batch no 24 loss =  1.2103825807571411
batch no 25 loss =  1.1710189580917358
batch no 26 loss =  1.215803861618042
batch no 27 loss =  1.157057285308838
batch no 28 loss =  1.1300220489501953
batch no 29 loss =  1.0991148948669434
batch no 30 loss =  1.0601304769515991
batch no 31 loss =  1.0228008031845093
batch no 32 loss =  1.0912868976593018
batch no 33 loss =  0.9856505990028381
batch no 34 loss =  0.9819539189338684
batch no 35 loss =  1.0306510925292969
batch no 36 loss =  0.9848380088806152
batch no 37 loss =  0.989624559879303
batch no 38 loss =  0.9476594924926758
batch no 39 loss =  0.9612619876861572
batch no 40 loss =  0.8986958861351013
batch no 41 loss =  0.9100659489631653
batch no 42 loss =  0.9121372699737549
batch no 43 loss =  0.9362645745277405
batch no 44 loss =  0.8442355990409851
batch no 45 loss =  0.8276706337928772
batch no 46 loss =  0.803051233291626
batch no 47 loss =  0.8012357354164124
batch no 48 loss =  0.7831454873085022
batch no 49 loss =  0.7508223056793213
batch no 50 loss =  0.8096636533737183
batch no 51 loss =  0.7562276721000671
batch no 52 loss =  0.764530599117279
batch no 53 loss =  0.7601000666618347
batch no 54 loss =  0.710577130317688
batch no 55 loss =  0.6691561341285706
batch no 56 loss =  0.6948921084403992
batch no 57 loss =  0.6729640364646912
batch no 58 loss =  0.6475369930267334
batch no 59 loss =  0.7118080854415894
batch no 60 loss =  0.6325834393501282
batch no 61 loss =  0.6638092398643494
batch no 62 loss =  0.6681099534034729
batch no 63 loss =  0.6030600070953369
batch no 64 loss =  0.6064748167991638
batch no 65 loss =  0.5821509957313538
batch no 66 loss =  0.5825245976448059
batch no 67 loss =  0.5944026708602905
batch no 68 loss =  0.5315052270889282
batch no 69 loss =  0.5525216460227966
batch no 70 loss =  0.5441896915435791
batch no 71 loss =  0.5337881445884705
batch no 72 loss =  0.5284881591796875
batch no 73 loss =  0.5651391744613647
batch no 74 loss =  0.5703287124633789
batch no 75 loss =  0.5127482414245605
batch no 76 loss =  0.5207087397575378
batch no 77 loss =  0.5141681432723999
batch no 78 loss =  0.492698073387146
batch no 79 loss =  0.46884605288505554
batch no 80 loss =  0.511565089225769
batch no 81 loss =  0.5005688071250916
batch no 82 loss =  0.5198782682418823
batch no 83 loss =  0.47924041748046875
batch no 84 loss =  0.4710450768470764
batch no 85 loss =  0.46020805835723877
batch no 86 loss =  0.4849032163619995
batch no 87 loss =  0.4721143841743469
batch no 88 loss =  0.4703673720359802
batch no 89 loss =  0.4440164566040039
batch no 90 loss =  0.4755094647407532
batch no 91 loss =  0.4569631814956665
batch no 92 loss =  0.47981956601142883
batch no 93 loss =  0.4545298218727112
batch no 94 loss =  0.44351571798324585
batch no 95 loss =  0.46746325492858887
batch no 96 loss =  0.47358807921409607
batch no 97 loss =  0.4544294774532318
batch no 98 loss =  0.4962926506996155
batch no 99 loss =  0.48395419120788574
batch no 100 loss =  0.44108179211616516
batch no 101 loss =  0.4728832542896271
batch no 102 loss =  0.4890323281288147
batch no 103 loss =  0.4716647267341614
batch no 104 loss =  0.48249995708465576
batch no 105 loss =  0.47763603925704956
batch no 106 loss =  0.46557939052581787
batch no 107 loss =  0.4828740060329437
batch no 108 loss =  0.4845884144306183
batch no 109 loss =  0.4839400351047516
batch no 110 loss =  0.533776044845581
batch no 111 loss =  0.5043434500694275
batch no 112 loss =  0.4820551872253418
batch no 113 loss =  0.4765048027038574
batch no 114 loss =  0.5152161121368408
batch no 115 loss =  0.5205423831939697
batch no 116 loss =  0.45017296075820923
batch no 117 loss =  0.5060521364212036
batch no 118 loss =  0.5015444755554199
batch no 119 loss =  0.5357353687286377
batch no 120 loss =  0.5455232262611389
batch no 121 loss =  0.5464028716087341
batch no 122 loss =  0.5404470562934875
batch no 123 loss =  0.5515747666358948
batch no 124 loss =  0.48644503951072693
batch no 125 loss =  0.5170029401779175
batch no 126 loss =  0.4973011016845703
batch no 127 loss =  0.5444160103797913
batch no 128 loss =  0.5032432079315186
batch no 129 loss =  0.5079444646835327
batch no 130 loss =  0.5730089545249939
batch no 131 loss =  0.5962690114974976
batch no 132 loss =  0.5836448073387146
batch no 133 loss =  0.5500009655952454
batch no 134 loss =  0.5935165286064148
batch no 135 loss =  0.5874279737472534
batch no 136 loss =  0.5787685513496399
batch no 137 loss =  0.6307970285415649
batch no 138 loss =  0.6106826663017273
batch no 139 loss =  0.6551612019538879
batch no 140 loss =  0.5904539227485657
batch no 141 loss =  0.6462585926055908
batch no 142 loss =  0.6867714524269104
batch no 143 loss =  0.6145952939987183
batch no 144 loss =  0.6478555202484131
batch no 145 loss =  0.6670128107070923
batch no 146 loss =  0.6696873903274536
batch no 147 loss =  0.6415550708770752
batch no 148 loss =  0.7150164842605591
batch no 149 loss =  0.6920061707496643
batch no 150 loss =  0.7582559585571289
batch no 151 loss =  0.7134975790977478
batch no 152 loss =  0.7674281001091003
batch no 153 loss =  0.757271409034729
batch no 154 loss =  0.7957542538642883
batch no 155 loss =  0.7181404232978821
batch no 156 loss =  0.7426053881645203
batch no 157 loss =  0.7954193353652954
batch no 158 loss =  0.822559118270874
batch no 159 loss =  0.8057528734207153
batch no 160 loss =  0.8383210897445679
batch no 161 loss =  0.8529059290885925
batch no 162 loss =  0.9087692499160767
batch no 163 loss =  0.8549301624298096
batch no 164 loss =  0.8934707641601562
batch no 165 loss =  0.8587768077850342
batch no 166 loss =  0.9172695279121399
batch no 167 loss =  0.9977843165397644
batch no 168 loss =  0.9188544750213623
batch no 169 loss =  0.9683499932289124
batch no 170 loss =  1.0196045637130737
batch no 171 loss =  1.0112369060516357
batch no 172 loss =  1.0482044219970703
batch no 173 loss =  1.0166860818862915
batch no 174 loss =  1.0707794427871704
batch no 175 loss =  1.0426247119903564
batch no 176 loss =  1.1214118003845215
batch no 177 loss =  1.141948938369751
batch no 178 loss =  1.169199824333191
batch no 179 loss =  1.1278570890426636
batch no 180 loss =  1.177224040031433
batch no 181 loss =  1.1449683904647827
batch no 182 loss =  1.2180325984954834
batch no 183 loss =  1.2685843706130981
batch no 184 loss =  1.1946467161178589
batch no 185 loss =  1.2430262565612793
batch no 186 loss =  1.3517838716506958
batch no 187 loss =  1.3116604089736938
batch no 188 loss =  1.3319159746170044
batch no 189 loss =  1.331680178642273
batch no 190 loss =  1.3623106479644775
batch no 191 loss =  1.3105368614196777
batch no 192 loss =  1.3695563077926636
batch no 193 loss =  1.5167524814605713

That is quite weird…
I don’t have much experience with LSTM training. Maybe the gradient clipping is doing something unexpected?

Yes, it is super weird. I have tried removing the gradient clipping; it did not make a difference.

But if there were an issue with the gradient clipping, wouldn’t the same effect show up in the CPU version too?

Should I add any CUDA-related flags while training?

I have only set a manual seed: torch.manual_seed(0).

You should not need any CUDA-specific settings.
If you run with different seeds, does the CPU version sometimes diverge as well?
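
For example, something like this (train_one_epoch here is just a placeholder for the training loop you posted above):

    for seed in (0, 1, 2):
        torch.manual_seed(seed)                    # seeds both the CPU and CUDA RNGs
        model = BiLinear(config).to(config.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
        train_one_epoch(model, optimizer, training_data)   # placeholder for your loop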

I partially ran the code with three different seed values for both the CPU and GPU versions.

The batch loss keeps decreasing for the CPU version but not for the GPU version.

I’m not sure why this could happen… Maybe @ptrblck has an idea?

What I would do is remove parts of the code until you no longer see this behavior, to pinpoint what is causing it. Sorry for the not-so-useful answer :confused:

I made the changes as you suggested. Now, I have only one learnable layer in my network.

This is what my network looks like -

    scores = torch.matmul(self.linear(mention_representation), label_representation.t())

where self.linear = nn.Linear(300, 300, bias=False), and mention_representation and label_representation come from pretrained embeddings.
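
Put together, the reduced model is essentially just this sketch (ReducedBiLinear is not my actual class, and the shapes in the comments are approximate):

    import torch
    import torch.nn as nn

    class ReducedBiLinear(nn.Module):
        def __init__(self, dim=300):
            super().__init__()
            self.linear = nn.Linear(dim, dim, bias=False)   # the only learnable layer

        def forward(self, mention_representation, label_representation):
            # mention_representation: [batch_size, 300], from pretrained embeddings
            # label_representation:   [num_labels, 300],  from pretrained embeddings
            # returned scores:        [batch_size, num_labels]
            return torch.matmul(self.linear(mention_representation),
                                label_representation.t())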

Observations:
The loss decreases to a certain point and then starts increasing, but the rate at which the loss increases is very low.

Thoughts/questions:
Do you think torch.matmul causes the issue?

Hi,
I don’t think matmul can be the problem; it is used everywhere.

Maybe it’s another part of your code, such as how you handle the labels or how the gradient steps are done? Not sure, really.

If I understand it correctly, you are using a single layer now and still observe the increasing loss on the GPU, while the loss decreases steadily on the CPU.
Is this correct?

If so, could you, just for the sake of debugging, use torch.set_flush_denormal(True) for the CPU run and compare the results?
I don’t think this is the issue, but I don’t know of any other difference between the CPU and GPU handling.
If that still reproduces the issue, could you upload a checkpoint and an input (if possible a single batch) so that we could debug it?

PS: Are you using the same data type (e.g. FP32) in both runs?
If so, could you try to run it with FP64?
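
i.e. something along these lines (just a sketch; BiLinear and config are from your code above):

    # CPU run only: flush denormal numbers to zero before training
    torch.set_flush_denormal(True)

    # FP64 check: convert all parameters and buffers to float64
    model = BiLinear(config).double().to(config.device)
    # any float tensors created manually (e.g. the length tensors) would need
    # .double() as well; the integer token indices for the embeddings stay long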

Yes, you have understood my problem correctly.
I ran the model with torch.set_flush_denormal(True) for the CPU version and the results were the same.

I am using FP32 to train my model.

I am averaging the word embeddings inside one of the functions -

    mention_embeddings = self.embeddings(mention_tokens)
    mention_lengths = x['mention_tokens_length'].float().to(self.device)

    average_mention_embeddings = torch.sum(mention_embeddings, 1) / mention_lengths.unsqueeze(1)
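
To make the shapes concrete, here is the same computation on made-up tensors (the sizes are just for illustration):

    import torch

    batch_size, num_tokens, emb_dim = 4, 7, 300                # made-up sizes
    mention_embeddings = torch.randn(batch_size, num_tokens, emb_dim)
    mention_lengths = torch.tensor([7., 5., 3., 6.])            # real tokens per example

    # sum over the token dimension, then divide by the per-example lengths
    average_mention_embeddings = torch.sum(mention_embeddings, 1) / mention_lengths.unsqueeze(1)
    print(average_mention_embeddings.shape)                     # torch.Size([4, 300])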

Do you think this can cause any issue?

P.S. I am using word embeddings to train the model so the checkpoint is a little large (1GB ish). Let me know if you still want it.

What shape does mention_embeddings have?
Are you seeing the same issue without the averaging?

If this doesn’t help, then yes, I would like to have a look at it.

The shape of mention_embeddings is [batch_size, no_of_tokens, emb_dimension].

Yes I am seeing the same issue without averaging.

I have created a folder on Google Drive with my checkpoints. I have also included a subset of the training data in the same folder.

Here’s the link -
https://drive.google.com/open?id=1vhMmpXB2snGF8OZmde1IDvwD9YhBdvWJ

Let me know if you need anything more!