Getting NaN values in backward pass

Hi, I am trying to implement this paper.
My implementation is here, in case you need any details about the architecture.

I am calculating the loss manually as the negative log of the probability of the target sequence. (The probability is calculated with the equation defined in the paper.)

The error appears after a few backward passes:

Warning: Traceback of forward call that caused the error:
  File "main.py", line 455, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 381, in train
    output = model(src, trg)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 346, in forward
    output = self.decoder(trg, hidden, encoder_outputs)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 308, in forward
    probabilities = self.stable_softmax(prediction)
  File "main.py", line 250, in stable_softmax
    numerator = torch.exp(z)
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "main.py", line 455, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 385, in train
    loss.backward()
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'ExpBackward' returned nan values in its 0th output.
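
The two-part traceback above (the forward call first, then the failing backward call) is the kind of report autograd's anomaly detection produces. A minimal sketch of enabling it on a toy example; the square root at zero is only a contrived way to force a NaN gradient and is not claimed to be what happens in the model above.

```python
import torch

message = ""
x = torch.tensor([0.0], requires_grad=True)

# Anomaly detection records the forward stack of every op and checks
# each backward function's outputs for NaN.
with torch.autograd.detect_anomaly():
    y = (torch.sqrt(x) * 0.0).sum()
    try:
        # sqrt backward computes grad / (2 * sqrt(x)) = 0 / 0 here
        y.backward()
    except RuntimeError as err:
        message = str(err)

print(message)
```

Anomaly detection slows training noticeably, so it is usually switched on only while debugging.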

The error occurs while calculating the gradients of the softmax:

  def stable_softmax(self, x):
    # Subtract the per-row max so the largest exponent is exp(0) = 1.
    z = x - torch.max(x, dim=1, keepdim=True).values
    numerator = torch.exp(z)
    denominator = torch.sum(numerator, dim=1, keepdim=True)
    softmax = numerator / denominator
    return softmax
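
Since the loss is the negative log of a softmax output, one common alternative is to compute log-probabilities directly with torch.nn.functional.log_softmax rather than taking exp and then log. This is a minimal sketch under the assumption that the target probability is a plain softmax over the vocabulary at each step; the equation in the paper may differ. The function name `sequence_nll`, the shapes, and the sample tensors are hypothetical.

```python
import torch
import torch.nn.functional as F

def sequence_nll(logits, targets):
    """Negative log-probability of `targets` under `logits`.

    logits:  (batch, vocab) unnormalised scores (assumed shape)
    targets: (batch,) integer class indices
    """
    # log_softmax never materialises the probabilities,
    # so there is no log(0) even for extreme logits.
    log_probs = F.log_softmax(logits, dim=1)
    picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    return -picked.sum()

logits = torch.randn(4, 10, requires_grad=True)  # made-up data
targets = torch.tensor([1, 0, 3, 9])
loss = sequence_nll(logits, targets)
loss.backward()
print(loss.item(), torch.isfinite(logits.grad).all().item())
```

For a plain softmax this matches F.cross_entropy with reduction="sum", which may be the simpler choice when no custom probability equation is involved.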

I am running it on the GPU.
torch.isnan(x).any() gives a False tensor.
torch.isnan(z).any() gives a False tensor.
So there are no NaNs in the forward pass up to this point. I am unable to figure out how a NaN appears in the backward pass.
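
A finite forward pass does not rule out a non-finite gradient arriving from later operations, since ExpBackward multiplies the incoming gradient by exp(z). One way to narrow this down is to attach tensor hooks and record whether the gradient reaching each intermediate is finite. A minimal sketch; the `watch` helper, the `grad_report` dictionary, and the toy loss are hypothetical and not part of the original code.

```python
import torch

grad_report = {}  # hypothetical container filled in by the hooks

def watch(name, tensor):
    """Record whether the gradient reaching `tensor` during backward is finite."""
    def hook(grad):
        grad_report[name] = bool(torch.isfinite(grad).all())
    tensor.register_hook(hook)
    return tensor

def stable_softmax(x):
    z = x - torch.max(x, dim=1, keepdim=True).values
    numerator = watch("numerator", torch.exp(z))
    denominator = watch("denominator", torch.sum(numerator, dim=1, keepdim=True))
    return watch("softmax", numerator / denominator)

x = torch.randn(3, 5, requires_grad=True)          # made-up input
loss = -torch.log(stable_softmax(x)[:, 0]).sum()   # toy loss for illustration
loss.backward()
print(grad_report)
```

If "softmax" already reports a non-finite gradient, the problem lies in whatever consumes the probabilities (for example the log in the loss) rather than in the softmax itself.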

Could you check the input to torch.exp and its output?
Maybe you are passing large values to it, so that the result might create an Inf output, which might result in a NaN in the backward pass.
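
A check like this can be sketched with a small helper that reports the range of a tensor together with NaN/Inf flags. The helper name `describe` and the sample values are made up for illustration.

```python
import torch

def describe(name, t):
    """Summarise a tensor: value range plus NaN/Inf flags."""
    t = t.detach()
    return (f"{name}: max {t.max().item():.6g} min {t.min().item():.6g} "
            f"nan={torch.isnan(t).any().item()} inf={torch.isinf(t).any().item()}")

z = torch.tensor([[0.0, -0.0193], [0.0, -0.0189]])  # made-up values
print(describe("before exp", z))
print(describe("after exp", torch.exp(z)))

# exp overflows float32 for inputs of roughly 89 and above
big = torch.tensor([100.0])
print(describe("overflow", torch.exp(big)))
```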

These values don’t seem to be very large. I am attaching the logs of the max/min values of the input and output of torch.exp:

Number of training examples: 12907
Number of validation examples: 5
Number of testing examples: 25
Unique tokens in source (en) vocabulary: 2804
Unique tokens in target (hi) vocabulary: 3501
The model has 214,411 trainable parameters
Before applying exponential
max 0.0 min -0.01930026337504387
After applying exponential
max 1.0 min 0.9808847904205322
Before applying exponential
max 0.0 min -0.018911730498075485
After applying exponential
max 1.0 min 0.9812659621238708
Before applying exponential
max 0.0 min -0.019006941467523575
After applying exponential
max 1.0 min 0.9811725616455078
Before applying exponential
max 0.0 min -0.018911048769950867
After applying exponential
max 1.0 min 0.9812666177749634
Before applying exponential
max 0.0 min -0.019007515162229538
After applying exponential
max 1.0 min 0.9811719655990601
Before applying exponential
max 0.0 min -0.01864556595683098
After applying exponential
max 1.0 min 0.9815271496772766
Before applying exponential
max 0.0 min -0.02009735256433487
After applying exponential
max 1.0 min 0.9801031947135925
Before applying exponential
max 0.0 min -0.018839091062545776
After applying exponential
max 1.0 min 0.9813372492790222
Before applying exponential
max 0.0 min -0.018604880198836327
After applying exponential
max 1.0 min 0.9815670251846313
Before applying exponential
max 0.0 min -0.020035836845636368
After applying exponential
max 1.0 min 0.9801635146141052
Before applying exponential
max 0.0 min -0.019339684396982193
After applying exponential
max 1.0 min 0.980846107006073
Before applying exponential
max 0.0 min -0.020187072455883026
After applying exponential
max 1.0 min 0.9800152778625488
Before applying exponential
max 0.0 min -0.01930253952741623
After applying exponential
max 1.0 min 0.9808825254440308
Before applying exponential
max 0.0 min -0.018910178914666176
After applying exponential
max 1.0 min 0.981267511844635
Before applying exponential
max 0.0 min -0.01948946714401245
After applying exponential
max 1.0 min 0.9806991815567017
Before applying exponential
max 0.0 min -0.01900828257203102
After applying exponential
max 1.0 min 0.981171190738678
Before applying exponential
max 0.0 min -0.020097119733691216
After applying exponential
max 1.0 min 0.9801034331321716
Before applying exponential
max 0.0 min -0.01883871853351593
After applying exponential
max 1.0 min 0.9813375473022461
Before applying exponential
max 0.0 min -0.01860380545258522
After applying exponential
max 1.0 min 0.9815680980682373
Before applying exponential
max 0.0 min -0.020035073161125183
After applying exponential
max 1.0 min 0.9801642894744873
Before applying exponential
max 0.0 min -0.019338484853506088
After applying exponential
max 1.0 min 0.9808472394943237
Before applying exponential
max 0.0 min -0.020188473165035248
After applying exponential
max 1.0 min 0.980013906955719
Before applying exponential
max 0.0 min -0.019300060346722603
After applying exponential
max 1.0 min 0.9808849096298218
Before applying exponential
max 0.0 min -0.018910465762019157
After applying exponential
max 1.0 min 0.9812671542167664
Before applying exponential
max 0.0 min -0.019487164914608
After applying exponential
max 1.0 min 0.9807014465332031
Before applying exponential
max 0.0 min -0.01900639571249485
After applying exponential
max 1.0 min 0.981173038482666
Before applying exponential
max 0.0 min -0.01883944869041443
After applying exponential
max 1.0 min 0.9813368916511536
Before applying exponential
max 0.0 min -0.018602870404720306
After applying exponential
max 1.0 min 0.9815691113471985
Before applying exponential
max 0.0 min -0.020035918802022934
After applying exponential
max 1.0 min 0.9801634550094604
Before applying exponential
max 0.0 min -0.019340507686138153
After applying exponential
max 1.0 min 0.9808453321456909
Before applying exponential
max 0.0 min -0.020190240815281868
After applying exponential
max 1.0 min 0.9800121784210205
Before applying exponential
max 0.0 min -0.019304800778627396
After applying exponential
max 1.0 min 0.9808803200721741
Before applying exponential
max 0.0 min -0.018909433856606483
After applying exponential
max 1.0 min 0.9812682271003723
Before applying exponential
max 0.0 min -0.019489070400595665
After applying exponential
max 1.0 min 0.9806995987892151
Before applying exponential
max 0.0 min -0.01900819130241871
After applying exponential
max 1.0 min 0.9811713099479675
Before applying exponential
max 0.0 min -0.018603935837745667
After applying exponential
max 1.0 min 0.9815680384635925
Before applying exponential
max 0.0 min -0.02003599889576435
After applying exponential
max 1.0 min 0.9801633954048157
Before applying exponential
max 0.0 min -0.019340313971042633
After applying exponential
max 1.0 min 0.9808454513549805
Before applying exponential
max 0.0 min -0.0201878622174263
After applying exponential
max 1.0 min 0.9800145030021667
Before applying exponential
max 0.0 min -0.019302817061543465
After applying exponential
max 1.0 min 0.9808822870254517
Before applying exponential
max 0.0 min -0.018910393118858337
After applying exponential
max 1.0 min 0.9812672734260559
Before applying exponential
max 0.0 min -0.019486038014292717
After applying exponential
max 1.0 min 0.9807025790214539
Before applying exponential
max 0.0 min -0.019006924703717232
After applying exponential
max 1.0 min 0.9811725616455078
Before applying exponential
max 0.0 min -0.020037759095430374
After applying exponential
max 1.0 min 0.9801616072654724
Before applying exponential
max 0.0 min -0.019339915364980698
After applying exponential
max 1.0 min 0.9808458685874939
Before applying exponential
max 0.0 min -0.020188800990581512
After applying exponential
max 1.0 min 0.9800136685371399
Before applying exponential
max 0.0 min -0.01930314116179943
After applying exponential
max 1.0 min 0.980881929397583
Before applying exponential
max 0.0 min -0.018910465762019157
After applying exponential
max 1.0 min 0.9812671542167664
Before applying exponential
max 0.0 min -0.019486617296934128
After applying exponential
max 1.0 min 0.9807019233703613
Before applying exponential
max 0.0 min -0.019006475806236267
After applying exponential
max 1.0 min 0.9811729788780212
Before applying exponential
max 0.0 min -0.01933939754962921
After applying exponential
max 1.0 min 0.9808463454246521
Before applying exponential
max 0.0 min -0.020187366753816605
After applying exponential
max 1.0 min 0.9800150394439697
Before applying exponential
max 0.0 min -0.019302254542708397
After applying exponential
max 1.0 min 0.9808828234672546
Before applying exponential
max 0.0 min -0.018910422921180725
After applying exponential
max 1.0 min 0.9812672734260559
Before applying exponential
max 0.0 min -0.0194871723651886
After applying exponential
max 1.0 min 0.9807014465332031
Before applying exponential
max 0.0 min -0.01900804415345192
After applying exponential
max 1.0 min 0.9811714291572571
Before applying exponential
max 0.0 min -0.02018953673541546
After applying exponential
max 1.0 min 0.9800128936767578
Before applying exponential
max 0.0 min -0.01930280774831772
After applying exponential
max 1.0 min 0.9808822870254517
Before applying exponential
max 0.0 min -0.0189106035977602
After applying exponential
max 1.0 min 0.9812670946121216
Before applying exponential
max 0.0 min -0.01948753371834755
After applying exponential
max 1.0 min 0.9807010293006897
Before applying exponential
max 0.0 min -0.019007209688425064
After applying exponential
max 1.0 min 0.9811722040176392
Before applying exponential
max 0.0 min -0.019301380962133408
After applying exponential
max 1.0 min 0.9808836579322815
Before applying exponential
max 0.0 min -0.018910270184278488
After applying exponential
max 1.0 min 0.9812673926353455
Before applying exponential
max 0.0 min -0.019488558173179626
After applying exponential
max 1.0 min 0.9807000756263733
Before applying exponential
max 0.0 min -0.01900819130241871
After applying exponential
max 1.0 min 0.9811713099479675
Before applying exponential
max 0.0 min -0.018910503014922142
After applying exponential
max 1.0 min 0.9812671542167664
Before applying exponential
max 0.0 min -0.019488057121634483
After applying exponential
max 1.0 min 0.9807005524635315
Before applying exponential
max 0.0 min -0.019006900489330292
After applying exponential
max 1.0 min 0.9811725616455078
Before applying exponential
max 0.0 min -0.019488198682665825
After applying exponential
max 1.0 min 0.9807004332542419
Before applying exponential
max 0.0 min -0.01900586299598217
After applying exponential
max 1.0 min 0.981173574924469
Before applying exponential
max 0.0 min -0.018645640462636948
After applying exponential
max 1.0 min 0.9815270900726318
Before applying exponential
max 0.0 min -0.020098166540265083
After applying exponential
max 1.0 min 0.9801024198532104
Before applying exponential
max 0.0 min -0.018839359283447266
After applying exponential
max 1.0 min 0.9813370108604431
Before applying exponential
max 0.0 min -0.018604330718517303
After applying exponential
max 1.0 min 0.9815676212310791
Before applying exponential
max 0.0 min -0.020036987960338593
After applying exponential
max 1.0 min 0.9801623821258545
Before applying exponential
max 0.0 min -0.019339658319950104
After applying exponential
max 1.0 min 0.980846107006073
Before applying exponential
max 0.0 min -0.02018718235194683
After applying exponential
max 1.0 min 0.9800151586532593
Before applying exponential
max 0.0 min -0.019302602857351303
After applying exponential
max 1.0 min 0.9808824062347412
Before applying exponential
max 0.0 min -0.01891184411942959
After applying exponential
max 1.0 min 0.9812658429145813
Before applying exponential
max 0.0 min -0.019488999620079994
After applying exponential
max 1.0 min 0.9806996583938599
Before applying exponential
max 0.0 min -0.01982579194009304
After applying exponential
max 1.0 min 0.9803694486618042
Before applying exponential
max 0.0 min -0.019007962197065353
After applying exponential
max 1.0 min 0.9811714887619019
Before applying exponential
max 0.0 min -0.02009640820324421
After applying exponential
max 1.0 min 0.9801041483879089
Before applying exponential
max 0.0 min -0.01883961260318756
After applying exponential
max 1.0 min 0.9813367128372192
Before applying exponential
max 0.0 min -0.01860237866640091
After applying exponential
max 1.0 min 0.9815695881843567
Before applying exponential
max 0.0 min -0.02003667876124382
After applying exponential
max 1.0 min 0.9801627397537231
Before applying exponential
max 0.0 min -0.0193393062800169
After applying exponential
max 1.0 min 0.9808464646339417
Before applying exponential
max 0.0 min -0.020189685747027397
After applying exponential
max 1.0 min 0.9800127744674683
Before applying exponential
max 0.0 min -0.01930147223174572
After applying exponential
max 1.0 min 0.9808835387229919
Before applying exponential
max 0.0 min -0.0189093928784132
After applying exponential
max 1.0 min 0.9812682271003723
Before applying exponential
max 0.0 min -0.01948663219809532
After applying exponential
max 1.0 min 0.9807019233703613
Before applying exponential
max 0.0 min -0.019825227558612823
After applying exponential
max 1.0 min 0.9803699851036072
Before applying exponential
max 0.0 min -0.01900714449584484
After applying exponential
max 1.0 min 0.9811723232269287
Before applying exponential
max 0.0 min -0.018840152770280838
After applying exponential
max 1.0 min 0.9813361763954163
Before applying exponential
max 0.0 min -0.018603524193167686
After applying exponential
max 1.0 min 0.981568455696106
Before applying exponential
max 0.0 min -0.020036067813634872
After applying exponential
max 1.0 min 0.9801632761955261
Before applying exponential
max 0.0 min -0.019339565187692642
After applying exponential
max 1.0 min 0.9808462262153625
Before applying exponential
max 0.0 min -0.020190367475152016
After applying exponential
max 1.0 min 0.9800121188163757
Before applying exponential
max 0.0 min -0.019302181899547577
After applying exponential
max 1.0 min 0.9808829426765442
Before applying exponential
max 0.0 min -0.018910503014922142
After applying exponential
max 1.0 min 0.9812671542167664
Before applying exponential
max 0.0 min -0.01948685571551323
After applying exponential
max 1.0 min 0.9807018041610718
Before applying exponential
max 0.0 min -0.019825518131256104
After applying exponential
max 1.0 min 0.9803696870803833
Before applying exponential
max 0.0 min -0.019008396193385124
After applying exponential
max 1.0 min 0.9811710715293884
Before applying exponential
max 0.0 min -0.01860380545258522
After applying exponential
max 1.0 min 0.9815680980682373
Before applying exponential
max 0.0 min -0.020034171640872955
After applying exponential
max 1.0 min 0.9801651835441589
Before applying exponential
max 0.0 min -0.01933983899652958
After applying exponential
max 1.0 min 0.9808459281921387
Before applying exponential
max 0.0 min -0.020188678056001663
After applying exponential
max 1.0 min 0.9800137281417847
Before applying exponential
max 0.0 min -0.01930253580212593
After applying exponential
max 1.0 min 0.9808825254440308
Before applying exponential
max 0.0 min -0.018911192193627357
After applying exponential
max 1.0 min 0.981266438961029
Before applying exponential
max 0.0 min -0.019490612670779228
After applying exponential
max 1.0 min 0.9806980490684509
Before applying exponential
max 0.0 min -0.019825948402285576
After applying exponential
max 1.0 min 0.9803692698478699
Before applying exponential
max 0.0 min -0.01900610886514187
After applying exponential
max 1.0 min 0.9811733365058899
Before applying exponential
max 0.0 min -0.02003747597336769
After applying exponential
max 1.0 min 0.9801619052886963
Before applying exponential
max 0.0 min -0.019340286031365395
After applying exponential
max 1.0 min 0.9808454513549805
Before applying exponential
max 0.0 min -0.02018577791750431
After applying exponential
max 1.0 min 0.9800165295600891
Before applying exponential
max 0.0 min -0.019299039617180824
After applying exponential
max 1.0 min 0.9808859825134277
Before applying exponential
max 0.0 min -0.018910575658082962
After applying exponential
max 1.0 min 0.9812670946121216
Before applying exponential
max 0.0 min -0.019486989825963974
After applying exponential
max 1.0 min 0.9807016253471375
Before applying exponential
max 0.0 min -0.01982567459344864
After applying exponential
max 1.0 min 0.980369508266449
Before applying exponential
max 0.0 min -0.019006211310625076
After applying exponential
max 1.0 min 0.9811732172966003
Before applying exponential
max 0.0 min -0.019340457394719124
After applying exponential
max 1.0 min 0.9808453321456909
Before applying exponential
max 0.0 min -0.020189344882965088
After applying exponential
max 1.0 min 0.9800130724906921
Before applying exponential
max 0.0 min -0.01930450275540352
After applying exponential
max 1.0 min 0.9808805584907532
Before applying exponential
max 0.0 min -0.01891041174530983
After applying exponential
max 1.0 min 0.9812672734260559
Before applying exponential
max 0.0 min -0.019486669450998306
After applying exponential
max 1.0 min 0.9807019233703613
Before applying exponential
max 0.0 min -0.019827280193567276
After applying exponential
max 1.0 min 0.9803679585456848
Before applying exponential
max 0.0 min -0.019005414098501205
After applying exponential
max 1.0 min 0.9811739921569824
Before applying exponential
max 0.0 min -0.02018619142472744
After applying exponential
max 1.0 min 0.9800161719322205
Before applying exponential
max 0.0 min -0.019303616136312485
After applying exponential
max 1.0 min 0.9808814525604248
Before applying exponential
max 0.0 min -0.0189113337546587
After applying exponential
max 1.0 min 0.9812663793563843
Before applying exponential
max 0.0 min -0.019487900659441948
After applying exponential
max 1.0 min 0.9807007312774658
Before applying exponential
max 0.0 min -0.019827045500278473
After applying exponential
max 1.0 min 0.9803681969642639
Before applying exponential
max 0.0 min -0.019005149602890015
After applying exponential
max 1.0 min 0.9811742305755615
Before applying exponential
max 0.0 min -0.01930542290210724
After applying exponential
max 1.0 min 0.9808796644210815
Before applying exponential
max 0.0 min -0.018911056220531464
After applying exponential
max 1.0 min 0.9812666177749634
Before applying exponential
max 0.0 min -0.019487980753183365
After applying exponential
max 1.0 min 0.980700671672821
Before applying exponential
max 0.0 min -0.019826088100671768
After applying exponential
max 1.0 min 0.9803690910339355
Before applying exponential
max 0.0 min -0.019005481153726578
After applying exponential
max 1.0 min 0.9811739325523376
Before applying exponential
max 0.0 min -0.018912117928266525
After applying exponential
max 1.0 min 0.9812655448913574
Before applying exponential
max 0.0 min -0.019485829398036003
After applying exponential
max 1.0 min 0.9807027578353882
Before applying exponential
max 0.0 min -0.019825410097837448
After applying exponential
max 1.0 min 0.9803697466850281
Before applying exponential
max 0.0 min -0.019007179886102676
After applying exponential
max 1.0 min 0.9811723232269287
Before applying exponential
max 0.0 min -0.019486064091324806
After applying exponential
max 1.0 min 0.9807025194168091
Before applying exponential
max 0.0 min -0.019826460629701614
After applying exponential
max 1.0 min 0.9803687930107117
Before applying exponential
max 0.0 min -0.019005997106432915
After applying exponential
max 1.0 min 0.9811734557151794
Before applying exponential
max 0.0 min -0.01982627622783184
After applying exponential
max 1.0 min 0.980368971824646
loss is tensor(97.8701, device='cuda:0', grad_fn=<NegBackward>)

You are right, these values look alright, but they also don’t produce the NaN issue, right?
Could you check the values for a couple more iterations, until you encounter the first NaN value?

Unfortunately, the code breaks in this very iteration, during backpropagation.

Before applying exponential
max 0.0 min -0.01982627622783184
After applying exponential
max 1.0 min 0.980368971824646
loss is tensor(97.8701, device='cuda:0', grad_fn=<NegBackward>)
Warning: Traceback of forward call that caused the error:
  File "main.py", line 468, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 386, in train
    output = model(src, trg)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 351, in forward
    output = self.decoder(trg, hidden, encoder_outputs)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 313, in forward
    probabilities = self.stable_softmax(prediction)
  File "main.py", line 253, in stable_softmax
    numerator = torch.exp(z)
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "main.py", line 468, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 394, in train
    loss.backward()
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'ExpBackward' returned nan values in its 0th output.

Thanks for the update.
Could you check the denominator as well and make sure that it’s not so small that the result might overflow?

Yes, sure. These are the max/min values of the denominator before getting an error.

max 3466.501953125 min 3464.61669921875
max 3466.2783203125 min 3462.81103515625
max 3466.5009765625 min 3464.6171875
max 3466.274658203125 min 3462.81591796875
max 3464.37353515625 min 3463.282470703125
max 3466.50537109375 min 3464.61962890625
max 3464.37109375 min 3463.28759765625
max 3466.51123046875 min 3464.6162109375
max 3466.273193359375 min 3462.81494140625
max 3464.3701171875 min 3463.287109375
max 3466.9697265625 min 3462.463623046875
max 3466.506591796875 min 3464.61962890625
max 3464.37255859375 min 3463.28173828125
max 3466.96875 min 3462.469482421875
max 3466.50732421875 min 3464.61767578125
max 3466.97021484375 min 3462.467041015625
max 3466.507080078125 min 3464.619873046875
max 3466.27734375 min 3462.8134765625
max 3464.37109375 min 3463.281982421875
max 3466.966796875 min 3462.467529296875
max 3468.5478515625 min 3464.38232421875
max 3466.505859375 min 3464.6142578125
max 3464.3720703125 min 3463.286865234375
max 3466.9716796875 min 3462.46923828125
max 3468.5419921875 min 3464.3779296875
max 3466.51025390625 min 3464.61572265625
max 3466.9677734375 min 3462.47021484375
max 3468.54345703125 min 3464.37744140625
max 3466.50634765625 min 3464.61865234375
max 3468.544677734375 min 3464.38330078125
max 3466.50244140625 min 3464.6181640625
max 3466.27099609375 min 3462.8115234375
max 3464.37548828125 min 3463.283935546875
max 3466.9716796875 min 3462.466064453125
max 3468.54248046875 min 3464.383544921875
max 3464.380615234375 min 3463.459716796875
max 3466.50634765625 min 3464.6181640625
max 3464.3740234375 min 3463.28271484375
max 3466.970703125 min 3462.4658203125
max 3468.54541015625 min 3464.37841796875
max 3464.37890625 min 3463.457763671875
max 3466.5078125 min 3464.6123046875
max 3466.970703125 min 3462.46826171875
max 3468.541748046875 min 3464.382568359375
max 3464.381591796875 min 3463.4609375
max 3466.505859375 min 3464.61669921875
max 3468.5390625 min 3464.3798828125
max 3464.38330078125 min 3463.45654296875
max 3466.50537109375 min 3464.61669921875
max 3464.377197265625 min 3463.455078125
max 3466.50390625 min 3464.615234375
max 3466.273681640625 min 3462.81298828125
max 3464.375 min 3463.28759765625
max 3466.964599609375 min 3462.46630859375
max 3468.542724609375 min 3464.37841796875
max 3464.3818359375 min 3463.45703125
max 3466.3603515625 min 3464.38037109375
max 3466.50537109375 min 3464.6181640625
max 3464.376953125 min 3463.2841796875
max 3466.96728515625 min 3462.462890625
max 3468.547119140625 min 3464.383056640625
max 3464.37841796875 min 3463.4580078125
max 3466.356201171875 min 3464.379638671875
max 3466.503662109375 min 3464.61474609375
max 3466.97119140625 min 3462.468505859375
max 3468.54248046875 min 3464.38232421875
max 3464.38232421875 min 3463.4599609375
max 3466.36083984375 min 3464.37939453125
max 3466.501953125 min 3464.62109375
max 3468.546875 min 3464.38037109375
max 3464.380859375 min 3463.458984375
max 3466.359130859375 min 3464.38134765625
max 3466.505859375 min 3464.61767578125
max 3464.380859375 min 3463.4541015625
max 3466.3515625 min 3464.379150390625
max 3466.506103515625 min 3464.614990234375
max 3466.35595703125 min 3464.380859375
max 3466.503173828125 min 3464.6171875
max 3466.27001953125 min 3462.81787109375
max 3464.369873046875 min 3463.279052734375
max 3466.9697265625 min 3462.46630859375
max 3468.541259765625 min 3464.378662109375
max 3464.3818359375 min 3463.45703125
max 3466.35888671875 min 3464.3828125
max 3465.03955078125 min 3464.380615234375
max 3466.50341796875 min 3464.61572265625
max 3464.373046875 min 3463.2841796875
max 3466.967529296875 min 3462.471435546875
max 3468.5419921875 min 3464.383544921875
max 3464.383544921875 min 3463.459716796875
max 3466.35791015625 min 3464.38134765625
max 3465.0439453125 min 3464.376220703125
max 3466.50341796875 min 3464.61572265625
max 3466.969482421875 min 3462.46630859375
max 3468.545166015625 min 3464.3857421875
max 3464.37841796875 min 3463.458984375
max 3466.35595703125 min 3464.379150390625
max 3465.04296875 min 3464.3798828125
max 3466.5078125 min 3464.61572265625
max 3468.539794921875 min 3464.3779296875
max 3464.376708984375 min 3463.46044921875
max 3466.359375 min 3464.378662109375
max 3465.0478515625 min 3464.376708984375
max 3466.50830078125 min 3464.61669921875
max 3464.377685546875 min 3463.4560546875
max 3466.3583984375 min 3464.381591796875
max 3465.04541015625 min 3464.3818359375
max 3466.50537109375 min 3464.6162109375
max 3466.352783203125 min 3464.381591796875
max 3465.0400390625 min 3464.379638671875
max 3466.505615234375 min 3464.6142578125
max 3465.03955078125 min 3464.38330078125
max 3466.5087890625 min 3464.6181640625
max 3466.27587890625 min 3462.8134765625
max 3464.370361328125 min 3463.28271484375
max 3466.96826171875 min 3462.46728515625
max 3468.545654296875 min 3464.382568359375
max 3464.380859375 min 3463.45654296875
max 3466.36083984375 min 3464.37939453125
max 3465.04931640625 min 3464.383056640625
max 3464.953125 min 3464.38037109375
max 3466.5029296875 min 3464.61767578125
max 3464.37109375 min 3463.28271484375
max 3466.9697265625 min 3462.4658203125
max 3468.545654296875 min 3464.38037109375
max 3464.383056640625 min 3463.455810546875
max 3466.3564453125 min 3464.37841796875
max 3465.0380859375 min 3464.3818359375
max 3464.95654296875 min 3464.376708984375
max 3466.50634765625 min 3464.614501953125
max 3466.96728515625 min 3462.470947265625
max 3468.54150390625 min 3464.3779296875
max 3464.383056640625 min 3463.4609375
max 3466.358154296875 min 3464.37939453125
max 3465.039306640625 min 3464.376953125
max 3464.9580078125 min 3464.38427734375
max 3466.49951171875 min 3464.61376953125
max 3468.54345703125 min 3464.3837890625
max 3464.381591796875 min 3463.45751953125
max 3466.358154296875 min 3464.3828125
max 3465.03857421875 min 3464.3798828125
max 3464.95361328125 min 3464.381103515625
max 3466.5009765625 min 3464.61572265625
max 3464.382568359375 min 3463.45263671875
max 3466.357666015625 min 3464.383056640625
max 3465.04296875 min 3464.376953125
max 3464.9599609375 min 3464.379150390625
max 3466.50732421875 min 3464.61572265625
max 3466.360107421875 min 3464.382568359375
max 3465.04443359375 min 3464.378662109375
max 3464.95947265625 min 3464.380859375
max 3466.506103515625 min 3464.6171875
max 3465.043701171875 min 3464.38134765625
max 3464.95849609375 min 3464.380126953125
max 3466.50341796875 min 3464.6171875
max 3464.95849609375 min 3464.383544921875
max 3466.50439453125 min 3464.61181640625
max 3466.27490234375 min 3462.81494140625
max 3464.37158203125 min 3463.2861328125
max 3466.9716796875 min 3462.463623046875
max 3468.54638671875 min 3464.376708984375
max 3464.38232421875 min 3463.4560546875
max 3466.359375 min 3464.38330078125
max 3465.0458984375 min 3464.3837890625
max 3464.957763671875 min 3464.3779296875
max 3466.45751953125 min 3464.38427734375
max 3466.5 min 3464.6181640625
max 3464.3759765625 min 3463.285400390625
max 3466.97021484375 min 3462.466552734375
max 3468.541015625 min 3464.37548828125
max 3464.3828125 min 3463.4560546875
max 3466.357177734375 min 3464.378173828125
max 3465.0439453125 min 3464.3837890625
max 3464.955322265625 min 3464.385498046875
max 3466.4580078125 min 3464.3779296875
max 3466.5029296875 min 3464.6181640625
max 3466.9716796875 min 3462.46337890625
max 3468.54296875 min 3464.380859375
max 3464.380126953125 min 3463.455078125
max 3466.3623046875 min 3464.384033203125
max 3465.044921875 min 3464.37890625
max 3464.95947265625 min 3464.380859375
max 3466.45556640625 min 3464.385009765625
max 3466.5048828125 min 3464.61572265625
max 3468.544921875 min 3464.384765625
max 3464.380859375 min 3463.45849609375
max 3466.35986328125 min 3464.376708984375
max 3465.04443359375 min 3464.378662109375
max 3464.953125 min 3464.379150390625
max 3466.45849609375 min 3464.3828125
max 3466.5048828125 min 3464.61767578125
max 3464.379150390625 min 3463.45947265625
max 3466.357421875 min 3464.3828125
max 3465.045654296875 min 3464.378173828125
max 3464.9609375 min 3464.380126953125
max 3466.4560546875 min 3464.380859375
max 3466.5009765625 min 3464.615234375
max 3466.36279296875 min 3464.377685546875
max 3465.045654296875 min 3464.3818359375
max 3464.95654296875 min 3464.385986328125
max 3466.45947265625 min 3464.38232421875
max 3466.506591796875 min 3464.617431640625
max 3465.04248046875 min 3464.385498046875
max 3464.9619140625 min 3464.382080078125
max 3466.45703125 min 3464.381103515625
max 3466.50830078125 min 3464.613525390625
max 3464.959716796875 min 3464.379638671875
max 3466.456298828125 min 3464.37646484375
max 3466.504638671875 min 3464.61669921875
max 3466.45556640625 min 3464.383056640625
max 3466.50341796875 min 3464.618408203125
max 3466.274658203125 min 3462.815185546875
max 3464.371826171875 min 3463.285888671875
max 3466.9697265625 min 3462.462646484375
max 3468.54248046875 min 3464.377197265625
max 3464.37646484375 min 3463.456787109375
max 3466.3583984375 min 3464.37939453125
max 3465.04443359375 min 3464.37890625
max 3464.95556640625 min 3464.378173828125
max 3466.45849609375 min 3464.38330078125
max 3465.380126953125 min 3464.380859375
max 3466.500244140625 min 3464.61767578125
max 3464.3740234375 min 3463.285888671875
max 3466.97412109375 min 3462.466552734375
max 3468.547119140625 min 3464.37646484375
max 3464.3818359375 min 3463.45947265625
max 3466.361083984375 min 3464.380126953125
max 3465.04345703125 min 3464.3818359375
max 3464.96142578125 min 3464.377685546875
max 3466.458251953125 min 3464.377685546875
max 3465.383544921875 min 3464.376953125
max 3466.50634765625 min 3464.61376953125
max 3466.97119140625 min 3462.464111328125
max 3468.543212890625 min 3464.3837890625
max 3464.37939453125 min 3463.4580078125
max 3466.35791015625 min 3464.3818359375
max 3465.039794921875 min 3464.382568359375
max 3464.952880859375 min 3464.381591796875
max 3466.461181640625 min 3464.385498046875
max 3465.382080078125 min 3464.38330078125
max 3466.501220703125 min 3464.614501953125
max 3468.54296875 min 3464.381591796875
max 3464.3798828125 min 3463.45654296875
max 3466.35693359375 min 3464.380126953125
max 3465.04345703125 min 3464.382080078125
max 3464.955078125 min 3464.37890625
max 3466.459228515625 min 3464.380859375
max 3465.386474609375 min 3464.38134765625
max 3466.505859375 min 3464.6162109375
max 3464.38427734375 min 3463.45458984375
max 3466.35888671875 min 3464.38427734375
max 3465.0419921875 min 3464.376708984375
max 3464.955078125 min 3464.376220703125
max 3466.458251953125 min 3464.37890625
max 3465.384033203125 min 3464.380126953125
max 3466.50537109375 min 3464.61767578125
max 3466.358642578125 min 3464.378662109375
max 3465.04443359375 min 3464.380615234375
max 3464.9580078125 min 3464.38330078125
max 3466.45849609375 min 3464.37890625
max 3465.38427734375 min 3464.3759765625
max 3466.5009765625 min 3464.6162109375
max 3465.039306640625 min 3464.3828125
max 3464.95654296875 min 3464.379150390625
max 3466.45654296875 min 3464.378662109375
max 3465.383544921875 min 3464.383544921875
max 3466.503662109375 min 3464.619140625
max 3464.96044921875 min 3464.3837890625
max 3466.45849609375 min 3464.382080078125
max 3465.381103515625 min 3464.384765625
max 3466.501220703125 min 3464.6171875
max 3466.459716796875 min 3464.380126953125
max 3465.38232421875 min 3464.38232421875
max 3466.50244140625 min 3464.61767578125
max 3465.3828125 min 3464.377197265625
max 3466.5078125 min 3464.614013671875
max 3466.27490234375 min 3462.81884765625
max 3464.3720703125 min 3463.287353515625
max 3466.97265625 min 3462.46435546875
max 3468.542236328125 min 3464.3818359375
max 3464.3798828125 min 3463.45654296875
max 3466.35986328125 min 3464.38037109375
max 3465.044677734375 min 3464.377685546875
max 3464.956787109375 min 3464.376708984375
max 3466.454345703125 min 3464.3828125
max 3465.382568359375 min 3464.379638671875
max 3464.5341796875 min 3464.37744140625
max 3466.501953125 min 3464.61767578125
max 3464.373046875 min 3463.284423828125
max 3466.9697265625 min 3462.462646484375
max 3468.543701171875 min 3464.38525390625
max 3464.378173828125 min 3463.45654296875
max 3466.359130859375 min 3464.382080078125
max 3465.03955078125 min 3464.382568359375
max 3464.957275390625 min 3464.38134765625
max 3466.459716796875 min 3464.377197265625
max 3465.38623046875 min 3464.380615234375
max 3464.535400390625 min 3464.378173828125
max 3466.50341796875 min 3464.613525390625
max 3466.970703125 min 3462.4658203125
max 3468.54296875 min 3464.3837890625
max 3464.384033203125 min 3463.45654296875
max 3466.3583984375 min 3464.380859375
max 3465.0400390625 min 3464.379150390625
max 3464.9580078125 min 3464.383056640625
max 3466.458984375 min 3464.37841796875
max 3465.386474609375 min 3464.38037109375
max 3464.535888671875 min 3464.3798828125
max 3466.5009765625 min 3464.6162109375
max 3468.545654296875 min 3464.383056640625
max 3464.3818359375 min 3463.46142578125
max 3466.35693359375 min 3464.382568359375
max 3465.041015625 min 3464.38037109375
max 3464.95751953125 min 3464.3828125
max 3466.456298828125 min 3464.384521484375
max 3465.37841796875 min 3464.37890625
max 3464.534423828125 min 3464.38525390625
max 3466.5068359375 min 3464.6162109375
max 3464.379150390625 min 3463.455810546875
max 3466.3583984375 min 3464.38134765625
max 3465.0478515625 min 3464.379150390625
max 3464.962890625 min 3464.3818359375
max 3466.457275390625 min 3464.3818359375
max 3465.385009765625 min 3464.376953125
max 3464.5341796875 min 3464.3779296875
max 3466.50634765625 min 3464.6201171875
max 3466.35791015625 min 3464.38037109375
max 3465.0419921875 min 3464.385498046875
max 3464.95458984375 min 3464.380859375
max 3466.45849609375 min 3464.380126953125
max 3465.38525390625 min 3464.38525390625
max 3464.531982421875 min 3464.3759765625
max 3466.5087890625 min 3464.6171875
max 3465.045654296875 min 3464.3876953125
max 3464.95361328125 min 3464.384033203125
max 3466.4560546875 min 3464.38037109375
max 3465.3828125 min 3464.381591796875
max 3464.531005859375 min 3464.37890625
max 3466.50830078125 min 3464.61474609375
max 3464.9521484375 min 3464.37646484375
max 3466.45751953125 min 3464.38134765625
max 3465.383056640625 min 3464.382568359375
max 3464.534423828125 min 3464.38427734375
max 3466.507080078125 min 3464.61767578125
max 3466.45361328125 min 3464.383056640625
max 3465.387939453125 min 3464.38623046875
max 3464.534423828125 min 3464.380615234375
max 3466.502685546875 min 3464.6171875
max 3465.38623046875 min 3464.3818359375
max 3464.5322265625 min 3464.380859375
max 3466.506103515625 min 3464.61474609375
max 3464.533447265625 min 3464.3759765625
loss is tensor(97.8701, device='cuda:0', grad_fn=<NegBackward>)
Warning: Traceback of forward call that caused the error:
  File "main.py", line 465, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 383, in train
    output = model(src, trg)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 348, in forward
    output = self.decoder(trg, hidden, encoder_outputs)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 310, in forward
    probabilities = self.stable_softmax(prediction)
  File "main.py", line 251, in stable_softmax
    numerator = torch.exp(z)
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "main.py", line 465, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 391, in train
    loss.backward()
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'ExpBackward' returned nan values in its 0th output.

Thanks for the update. I'm still unsure where this might be coming from.
Could you check all parameters and gradients for invalid values?
You could use torch.isfinite(tensor) to check for valid values.
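
A minimal sketch of such a check, assuming a standard `nn.Module`. The helper name `report_nonfinite` is hypothetical, and the toy `nn.Linear` model only stands in for the real one. It reports offending tensors by name and skips parameters that have no gradient yet:

```python
import torch
import torch.nn as nn

def report_nonfinite(model):
    """Return the names of parameters or gradients that contain NaN or Inf."""
    bad = []
    for name, param in model.named_parameters():
        if not torch.isfinite(param.detach()).all():
            bad.append(name)
        # param.grad is None until backward() has populated it
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name + ".grad")
    return bad

# Toy usage: a small layer stands in for the real model (assumption).
model = nn.Linear(3, 2)
model(torch.ones(1, 3)).sum().backward()
clean = report_nonfinite(model)

# Corrupt one weight on purpose to show what a hit looks like.
with torch.no_grad():
    model.weight[0, 0] = float("nan")
dirty = report_nonfinite(model)
```

Calling it once after `loss.backward()` and once after `optimizer.step()` would show whether the invalid values first appear in the gradients or in the updated weights. It does not inspect intermediate activations, so a clean report does not rule out a problem inside the forward pass.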

Alright, I have printed this and the output seems fine (all tensors are valid).

    for param in model.parameters():
      print("param.data",torch.isfinite(param.data).all())
      print("param.grad.data",torch.isfinite(param.grad.data).all(),"\n")

Please tell me if you meant for me to check something else.

After optimization step

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

param.data tensor(True, device='cuda:0')
param.grad.data tensor(True, device='cuda:0') 

loss is tensor(97.8701, device='cuda:0', grad_fn=<NegBackward>)
Warning: Traceback of forward call that caused the error:
  File "main.py", line 476, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 383, in train
    output = model(src, trg)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 348, in forward
    output = self.decoder(trg, hidden, encoder_outputs)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 310, in forward
    probabilities = self.stable_softmax(prediction)
  File "main.py", line 251, in stable_softmax
    numerator = torch.exp(z)
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "main.py", line 476, in <module>
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  File "main.py", line 391, in train
    loss.backward()
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/shreyansh/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'ExpBackward' returned nan values in its 0th output.

Thanks for the great debugging so far.
Are you able to reproduce this issue using random input data?
If so, could you post the complete code including all seeds etc. which would reproduce this issue please, so that we can debug further?

I am pasting the code here; alternatively, the code and data can be cloned from here.
The model won't work well for longer sequences because it takes a product of probabilities. However, I get the error even when testing on a short sequence length (say 5). The error can be reproduced on random input data if the sequence length is restricted.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchtext.datasets import TranslationDataset
from torchtext.data import Field, BucketIterator

import spacy
import numpy as np

import random
import math
import time

SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.autograd.set_detect_anomaly(True)

spacy_en = spacy.load('en')

def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings (tokens) and reverses it
    """
    return [tok.text for tok in spacy_en.tokenizer(text)][::-1]

def tokenize_hi(text):
    """
    Tokenizes Hindi text from a string into a list of strings (tokens) 
    """
    return text.split()

SRC = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

TRG = Field(tokenize = tokenize_hi, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

train_data, valid_data, test_data  = TranslationDataset.splits(
                                      path='IITB_small',
                                      validation='dev',
                                      exts = ('.en', '.hi'), 
                                      fields = (SRC, TRG))

print(f"Number of training examples: {len(train_data.examples)}")
print(f"Number of validation examples: {len(valid_data.examples)}")
print(f"Number of testing examples: {len(test_data.examples)}")

vars(train_data.examples[0])

SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2,specials=['<pad>','<sop>','<eop>'])

print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (hi) vocabulary: {len(TRG.vocab)}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device 

BATCH_SIZE = 2

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE, 
    device = device)

"""# EnCoder Parameters"""

input_dim = len(SRC.vocab)
embed_dim = 10
hidden_dim = 10
segment_dim = 10
n_layers = 6
dropout = 0.4
segment_threshold = 5
temperature = 0.1

"""# Building Encoder"""

class Encoder(nn.Module):
  def __init__(self,input_dim,embed_dim,hidden_dim,segment_dim,n_layers,dropout,segment_threshold,device):
    super().__init__()
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.segment_threshold = segment_threshold
    self.segment_dim = segment_dim
    self.device = device
    
    self.embedding = nn.Embedding(input_dim,embed_dim)
    self.rnn = nn.GRU(embed_dim,hidden_dim,n_layers,dropout=dropout,bidirectional=True)

    self.segmentRnn = nn.GRU(hidden_dim*2,segment_dim,n_layers,dropout=dropout)
    self.fc = nn.Linear(hidden_dim*2,hidden_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self,input):

    #input = [src len, batch size]
    embedded = self.dropout(self.embedding(input))
    #embedded = [src len, batch size, emb dim]

    outputs, hidden = self.rnn(embedded)
    #outputs = [src len, batch size, hid dim * num directions]
    #hidden = [n layers * num directions, batch size, hid dim]
        
    segment_encoding, hidden = self.segment_rnn(outputs)
    #segment_encoding = [src len* (src len+1)/2, batch size, segment_dim*num_directions]
    #hidden = [n layers * num_directions, batch size, hid dim]

    # hidden = torch.tanh(self.fc(torch.cat((hidden[-2],hidden[-1]),dim=1)))

    return segment_encoding,hidden

  def segment_rnn(self,outputs):
    N = outputs.shape[0]
    batch_size = outputs.shape[1]
    dp_forward = torch.zeros(N, N, batch_size, self.segment_dim).to(self.device)
    dp_backward = torch.zeros(N, N, batch_size, self.segment_dim).to(self.device)

    for i in range(N):
      hidden_forward = torch.randn(self.n_layers, batch_size, self.hidden_dim).to(self.device)
      for j in range(i, min(N, i + self.segment_threshold)):
        
        # outputs[j] = [batch size, hidden_dim* num_direction]
        next_input = outputs[j].unsqueeze(0)
        # next_input = [1, batch size, hidden_dim* num_direction]
        
        out, hidden_forward = self.segmentRnn(next_input,hidden_forward)
        #out = [1, batch size, segment_dim]
        #hidden_forward = [n layers , batch size, hid dim]

        dp_forward[i][j] = out.squeeze(0)

    for i in range(N):
      hidden_backward = torch.randn(self.n_layers, batch_size, self.hidden_dim).to(self.device)
      for j in range(i, max(-1, i - self.segment_threshold), -1):

        # outputs[j] = [batch size, hidden_dim* num_direction]
        next_input = outputs[j].unsqueeze(0)
        # next_input = [1, batch size, hidden_dim* num_direction]
        
        out, hidden_backward = self.segmentRnn(next_input,hidden_backward)
        #out = [1, batch size, segment_dim]
        #hidden_backward = [n layers , batch size, hid dim]
        
        dp_backward[j][i] = out.squeeze(0)
    
    dp = torch.cat((dp_forward,dp_backward),dim=3)
    dp_indices = torch.triu_indices(N, N)
    dp = dp[dp_indices[0],dp_indices[1]]
    return dp,torch.cat((hidden_forward,hidden_backward),dim=2)

"""# Defining Attn Network"""
'''
Attention is calculated over encoder_outputs S(i,j) and context representation
of previously generated segments (from Target Decoder)
'''
class Attention(nn.Module):
  def __init__(self, enc_hid_dim, dec_hid_dim):
    super().__init__()

    self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
    self.v = nn.Linear(dec_hid_dim, 1, bias = False)

  def forward(self, encoder_outputs, output_target_decoder):
      
    #encoder_outputs = [no. of segments, batch size, enc hid dim * 2]
    #output_target_decoder = [batch size, dec hid dim]
    batch_size = encoder_outputs.shape[1]
    src_len = encoder_outputs.shape[0]
    
    #repeat decoder hidden state src_len times
    output_target_decoder = output_target_decoder.unsqueeze(1).repeat(1, src_len, 1)
    
    encoder_outputs = encoder_outputs.permute(1, 0, 2)
    
    #output_target_decoder = [batch size, no. of segments, dec hid dim]
    #encoder_outputs = [batch size, no. of segments, enc hid dim * 2]
    
    energy = torch.tanh(self.attn(torch.cat((output_target_decoder, encoder_outputs), dim = 2))) 
    #energy = [batch size,  no. of segments, dec hid dim]
    attention = self.v(energy).squeeze(2)
    #attention= [batch size,  no. of segments]
    a = F.softmax(attention, dim=1)
    #a = [batch size,  no. of segments]
    a = a.unsqueeze(1)
    #a = [batch size, 1,  no. of segments]
    weighted = torch.bmm(a, encoder_outputs)
    #weighted = [batch size, 1, enc hid dim * 2]
    weighted = weighted.permute(1, 0, 2)
    #weighted = [1, batch size, enc hid dim * 2]
    return weighted
    

"""# Decoder Parameters"""

output_dim = len(TRG.vocab)
DEC_HEADS = 8
DEC_PF_DIM = 512
# embed_dim = 256
# hidden_dim = 256
# segment_dim = 256
# n_layers = 6
# dropout = 0.4
# segment_threshold = 5

"""# Building Decoder"""

class Decoder(nn.Module):
  def __init__(self, output_dim, embed_dim, hidden_dim,segment_dim,n_layers, dropout, attention):
    super().__init__()
    self.output_dim = output_dim
    self.n_layers = n_layers
    self.hidden_dim = hidden_dim
    self.attention = attention
    self.device = device
    self.embedding = nn.Embedding(self.output_dim, embed_dim)
    self.rnn = nn.GRU(embed_dim,hidden_dim,n_layers,dropout=dropout)
    self.rnn = nn.GRU(embed_dim,hidden_dim,n_layers,dropout=dropout)
    self.segmentRnn = nn.GRU(hidden_dim,hidden_dim,n_layers,dropout=dropout)
    self.fc_out = nn.Linear((hidden_dim * 2) + hidden_dim + embed_dim, self.output_dim)
    # self.soft = nn.LogSoftmax(dim=1)
    self.soft = nn.Softmax(dim=1)
    self.dropout = nn.Dropout(dropout)
    
  def stable_softmax(self,x):
    z = x - torch.max(x,dim=1,keepdim=True).values
    numerator = torch.exp(z)
    denominator = torch.sum(numerator, dim=1, keepdims=True)
    softmax = numerator / denominator
    return softmax
  
  def forward(self, input, hidden, encoder_outputs):
          
    #input = [target_len,batch size]
    #hidden = [batch size, dec hid dim]
    #encoder_outputs = [src len, batch size, enc hid dim * 2]
    
    embedded = self.embedding(input)
    #embedded = [target_len, batch size, emb dim]
    
    output_target_decoder,hidden_target_decoder = self.rnn(embedded)
    #output_target_decoder = [target_len, batch size, hidden_dim]
    #hidden_target_decoder = [n layers , batch size, hidden_dim]
    
    trg_len = input.shape[0]
    batch_size = input.shape[1]
    trg_vocab_size = self.output_dim
    # later to be passed in constructor (currently accessing through Globals)
    sop_symbol = TRG.vocab.stoi['<sop>']
    eop_symbol = TRG.vocab.stoi['<eop>']
    
    alpha = torch.zeros(batch_size,trg_len).to(self.device)
    alpha[:,0] = 1
    for end in range(1,trg_len):

      for phraseLen in range(end,0,-1):
        start = end - phraseLen + 1
        weighted = self.attention(encoder_outputs, output_target_decoder[start-1])
        
        sop_vector = (torch.ones(1,batch_size,dtype=torch.int64)*sop_symbol).to(self.device)
        input_phrase = input[start:end+1,:]
        input_phrase = torch.cat((sop_vector,input_phrase),0)
        eop_vector = (torch.ones(1,batch_size,dtype=torch.int64)*eop_symbol).to(self.device)
        input_phrase = torch.cat((input_phrase,eop_vector),0)
        
        phraseEmbedded = self.embedding(input_phrase)
        
        # currEmbedded = phraseEmbedded[0,:,:]
        # rnn_input = torch.cat((currEmbedded.unsqueeze(0), weighted), dim = 2)
        
        phraseProb = torch.ones(batch_size).to(self.device)
        for t in range(input_phrase.shape[0]-1):
          rnn_input = phraseEmbedded[t].unsqueeze(0)
          output, hidden = self.segmentRnn(rnn_input)
          
          output = output.squeeze(0)
          weighted = weighted.squeeze(0)
          rnn_input = rnn_input.squeeze(0)
          
          prediction = self.fc_out(torch.cat((output, weighted, rnn_input), dim = 1))
          #prediction = [batch size, output dim]
          # probabilities = self.soft(prediction)
          # phraseProb *= torch.exp(probabilities[torch.arange(batch_size),input_phrase[t+1]])

          probabilities = self.stable_softmax(prediction)
          phraseProb *= probabilities[torch.arange(batch_size),input_phrase[t+1]]
        
        alpha[:,end] = alpha[:,end].clone() + phraseProb*alpha[:,start-1].clone()
    
    return alpha
      
class NP2MT(nn.Module):
  def __init__(self, encoder, decoder, device):
    super().__init__()
    
    self.encoder = encoder
    self.decoder = decoder
    self.device = device
      
  def forward(self, src, trg, teacher_forcing_ratio = 0.5):
    
    #src = [src len, batch size]
    #trg = [trg len, batch size]
    #teacher_forcing_ratio is probability to use teacher forcing
    #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
    
    batch_size = src.shape[1]
    trg_len = trg.shape[0]
    trg_vocab_size = self.decoder.output_dim
    
    ''' moved this to Decoder now
    # later to be passed in constructor (currently accessing through Globals)
    sop_symbol = TRG.vocab.stoi['<sop>']
    eop_symbol = TRG.vocab.stoi['<eop>']
    
    #tensor to store decoder outputs
    outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
    '''
    
    #encoder_outputs is representation of all phrases states of the input sequence, back and forwards
    #hidden is the final forward and backward hidden states, passed through a linear layer (batch_size*hidden_dim)
    encoder_outputs, hidden = self.encoder(src)
    output = self.decoder(trg, hidden, encoder_outputs)
    return output[:,-1]

attn = Attention(hidden_dim, hidden_dim)
enc = Encoder(input_dim, embed_dim, hidden_dim, segment_dim, n_layers, dropout, segment_threshold, device)
dec = Decoder(output_dim, embed_dim, hidden_dim, segment_dim, n_layers, dropout, attn)

model = NP2MT(enc, dec, device).to(device)

def init_weights(m):
  for name, param in m.named_parameters():
    if 'weight' in name:
      nn.init.normal_(param.data, mean=0, std=0.01)
    else:
      nn.init.constant_(param.data, 0)
            
model.apply(init_weights)

def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters())
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

def train(model, iterator, optimizer, criterion, clip):
  
  model.train()
  
  epoch_loss = 0
  
  for i, batch in enumerate(iterator):
    
    src = batch.src
    trg = batch.trg
    
    optimizer.zero_grad()
    
    output = model(src, trg)
    
    loss = -torch.log(output).mean()

    loss.backward()
    
    # print("Before optimization step\n\n")
    # for name, param in model.named_parameters():
    #   if param.requires_grad:
    #       print(name, torch.isnan(param.data).any())
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    
    optimizer.step()

    # print("After optimization step\n\n")

    # for name, param in model.named_parameters():
    #   if param.requires_grad:
    #       print(name, torch.isnan(param.data).any())
    
    epoch_loss += loss.item()
    
  return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    
  model.eval()
  
  epoch_loss = 0
  
  with torch.no_grad():

    for i, batch in enumerate(iterator):

        src = batch.src
        trg = batch.trg

        output = model(src, trg, 0) #turn off teacher forcing

        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]

        output_dim = output.shape[-1]
        
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)

        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]

        loss = criterion(output, trg)

        epoch_loss += loss.item()
      
  return epoch_loss / len(iterator)

def epoch_time(start_time, end_time):
  elapsed_time = end_time - start_time
  elapsed_mins = int(elapsed_time / 60)
  elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
  return elapsed_mins, elapsed_secs

N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
  start_time = time.time()
  
  train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
  # valid_loss = evaluate(model, valid_iterator, criterion)
  
  end_time = time.time()
  
  epoch_mins, epoch_secs = epoch_time(start_time, end_time)
  
  # if valid_loss < best_valid_loss:
  #   best_valid_loss = valid_loss
  #   torch.save(model.state_dict(), 'npmt-model.pt')
  
  print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
  print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
  # print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

# model.load_state_dict(torch.load('npmt-model.pt'))

# test_loss = evaluate(model, test_iterator, criterion)

# print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Thanks @ptrblck for the help, I finally found the issue.
[image: the equation for alpha from the paper]
The alpha values here denote probabilities after projection over the target vocabulary, and I was implementing this equation as is.
As you can see, it is a summation over products of probabilities, and because of this the alpha values were underflowing.
I wasn't aware of the logsumexp trick, which avoids this underflow by computing all probabilities in log scale.
If anyone is interested, go through this link to understand the trick. The implementation can be found here - https://pytorch.org/docs/master/generated/torch.logsumexp.html
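
For scale, a loss near 97.87 corresponds to a probability of roughly e^-97.87 ≈ 3e-43, which is already below float32's smallest normal value (about 1.2e-38), so products of such terms can reach zero and gradients of the form 1/alpha can overflow. Below is a minimal sketch of the same kind of recursion carried out in log space. It is not the code from the repository: `forward_log_alpha` and `log_phrase_prob` are hypothetical names, the batch dimension is dropped, and the per-phrase log-probabilities are filled with toy values.

```python
import math
import torch

def forward_log_alpha(log_phrase_prob):
    # log_phrase_prob[start, end]: log-probability of emitting target tokens
    # start..end as one phrase (hypothetical input, shape [T, T]).
    # Position 0 is the start token, so log_alpha[0] = log(1) = 0.
    T = log_phrase_prob.shape[0]
    log_alpha = [torch.zeros(())]
    for end in range(1, T):
        # One term per phrase start: log(alpha[start-1] * p(start..end)).
        terms = torch.stack([
            log_alpha[start - 1] + log_phrase_prob[start, end]
            for start in range(1, end + 1)
        ])
        # Log of the summed probabilities, computed without leaving log space.
        log_alpha.append(torch.logsumexp(terms, dim=0))
    return torch.stack(log_alpha)

# Toy data: every token has log-probability -60, so a phrase of n tokens has -60 * n.
T = 4
idx = torch.arange(T)
phrase_len = (idx[None, :] - idx[:, None] + 1).clamp(min=0).float()
log_phrase_prob = (-60.0 * phrase_len).requires_grad_()

log_alpha = forward_log_alpha(log_phrase_prob)
loss = -log_alpha[-1]  # negative log-likelihood, no torch.log of a tiny number
loss.backward()

# In float32 the same quantity in linear space is exactly zero.
underflowed = torch.exp(torch.tensor(-180.0))
```

In a model like the one above, the per-step term would presumably come from `F.log_softmax(prediction, dim=1)`, summed over the phrase instead of multiplying softmax outputs, and the loss would be `-log_alpha[-1]` directly.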

logsumexp has saved my arse a time or two also :)