Strange NaN behavior

Even though none of the weights or biases is NaN, the output of softmax/logsoftmax is NaN. Below is a dump of the weights and biases (all finite), followed by the action output, which is entirely NaN:

(56, '------------------------------------------------------------')
[weights/biases truncated for readability: the printed list contains one scalar and nine torch.cuda.FloatTensor slices of sizes 10, 100, 10, 50, 5, 10, 1, 500 and 50 (GPU 0); every value is an ordinary finite float, none is NaN or inf]
ACTION:
Variable containing:
nan
nan
nan
nan
nan
[torch.cuda.FloatTensor of size 5 (GPU 0)]
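
For what it's worth, nn.LogSoftmax should not be where the NaN is created: for finite inputs it is evaluated with the usual max-subtraction trick and stays finite, so a NaN in ACTION means the tensor fed into it (the output of _action_linear) already contains NaN or inf. A minimal sketch, written against a recent PyTorch API (with the 0.3-era API used below you would drop the dim argument and wrap tensors in Variable):

import torch
from torch import nn

log_softmax = nn.LogSoftmax(dim=1)

# Finite input -> finite output: LogSoftmax does not create NaNs on its own.
finite = torch.randn(2, 5)
print(log_softmax(finite))        # ordinary finite log-probabilities

# A single NaN in the input poisons its whole row, because the row's
# log-sum-exp becomes NaN.
poisoned = torch.randn(2, 5)
poisoned[0, 3] = float('nan')
out = log_softmax(poisoned)
print(out)                        # row 0 is all NaN, row 1 is still finite
print((out != out).any())         # version-agnostic NaN check -> True

So the real question is which layer upstream first produces a non-finite value; the hook sketch after the model code is one way to locate it.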

model.py:

import torch
from linear_multi import LinearMulti
from torch import nn
from torch.legacy.nn import Add, Sum, Identity
from torch.autograd import Variable

class Encoder(nn.Module):
    def __init__(self, in_dim, hidsz):
        super(Encoder, self).__init__()
        self.lut = nn.Embedding(in_dim, hidsz) # in_dim agents, returns (batchsz, x, hidsz)
        self._bias = nn.Parameter(torch.randn(hidsz))

    def forward(self, inp):
        x = self.lut(inp)
        x = torch.sum(x, 1) # XXX: the original version is sum(2) but lua is 1-indexed
        x = x.add(self._bias) # XXX:
        return x

class CommNet(nn.Module):
    def __init__(self, opts):
        super(CommNet, self).__init__()
        self.opts = opts
        self.nmodels = opts['nmodels']
        self.nagents = opts['nagents']
        self.hidsz = opts['hidsz']
        self.nactions = opts['nactions']
        self.use_lstm = opts['model'] == 'lstm'

        # Comm
        if self.opts['comm_encoder']:
            # before merging comm and hidden, use a linear layer for comm
            if self.use_lstm: # LSTM has 4x weights for gates
                self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
            else:
                self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # RNN: (comm + hidden) -> hidden
        if self.use_lstm:
            self._rnn_enc = self.__build_encoder(self.hidsz * 4)
            self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
        else:
            self._rnn_enc = self.__build_encoder(self.hidsz)
            self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # Action layer
        self._action_linear = LinearMulti(self.nmodels, self.hidsz, self.nactions)
        self._action_baseline_linear = LinearMulti(self.nmodels, self.hidsz, 1)

        # Comm_out
        self._comm_out_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * self.nagents)
        if self.opts['comm_decoder'] >= 1:
            self._comm_out_linear_alt = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # action_comm
        nactions_comm = self.opts['nactions_comm']
        if nactions_comm > 1:
            self._action_comm_linear = LinearMulti(self.nmodels, self.hidsz, nactions_comm)

    def forward(self, inp, prev_hid, prev_cell, model_ids, comm_in):
        self.model_ids = model_ids
        comm2hid = self.__comm2hid(comm_in)
        # below are the return values, for next time step
        if self.use_lstm:
            hidstate, prev_cell = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)
        else:
            hidstate = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)

        action_prob, baseline = self.__action(hidstate)

        comm_out = self.__comm_out(hidstate)

        if self.opts['nactions_comm'] > 1:
            action_comm = self.__action_comm(hidstate)
            return (action_prob, baseline, hidstate, comm_out, action_comm)
        else:
            return (action_prob, baseline, hidstate, comm_out)

    def __comm2hid(self, comm_in):
        # Lua Sum(2) -> Python sum(1), shape: [batch x nagents, hidden]
        comm2hid = torch.sum(comm_in, 1) # XXX: sum(2) -> 0-index
        if self.opts['comm_encoder']:
            comm2hid = self._comm2hid_linear(comm2hid, self.model_ids)
        return comm2hid

    def __hidstate(self, inp, prev_hid, prev_cell, comm2hid):
        if self.opts['model'] == 'mlp' or self.opts['model'] == 'rnn':
            hidstate = self._rnn(inp, prev_hid, comm2hid)
        elif self.use_lstm:
            hidstate, cellstate = self._lstm(inp, prev_hid, prev_cell, comm2hid)
            return hidstate, cellstate
        else:
            raise Exception('model not supported')
        return hidstate

    def _lstm(self, inp, prev_hid, prev_cell, comm_in):
        pre_hid = []
        pre_hid.append(self._rnn_enc(inp))
        pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
        # if comm_in:
        pre_hid.append(comm_in)
        A = sum(pre_hid)
        B = A.view(-1, 4, self.hidsz)
        C = torch.split(B, self.hidsz, 0)

        gate_forget = nn.Sigmoid()(C[0][0])
        gate_write = nn.Sigmoid()(C[0][1])
        gate_read = nn.Sigmoid()(C[0][2])
        in2c = self.__nonlin()(C[0][3])
        # print gate_forget.size(), prev_cell.size()
        # print in2c.size(), gate_write.transpose(0,1).size()
        cellstate = sum([
            torch.matmul(gate_forget, prev_cell),
            torch.matmul(in2c.transpose(0,1), gate_write)
        ])
        hidstate = torch.matmul(self.__nonlin()(cellstate), gate_read)
        return hidstate, cellstate

    def _rnn(self, inp, prev_hid, comm_in):
        pre_hid = []
        pre_hid.append(self._rnn_enc(inp))

        # print("_rnn_linear_weight")
        # print(self._rnn_linear.weight_lut.weight)
        # print(self._rnn_linear.bias_lut.weight)
        pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
        # if comm_in:
        pre_hid.append(comm_in)

        sum_pre_hid = sum(pre_hid)
        hidstate = self.__nonlin()(sum_pre_hid)
        return hidstate

    def __action(self, hidstate):
        # print('action_linear')
        # print(self._action_linear.weight_lut.weight)
        # print(self._action_linear.bias_lut.weight)
        action = self._action_linear(hidstate, self.model_ids)
        action_prob = nn.LogSoftmax()(action) # was LogSoftmax

        # print('action_baseline_linear')
        # print(self._action_baseline_linear.weight_lut.weight)
        # print(self._action_baseline_linear.bias_lut.weight)
        baseline =  self._action_baseline_linear(hidstate, self.model_ids)

        return action_prob, baseline

    def __comm_out(self, hidstate):
        if self.opts['fully_connected']:
            # use different params depending on agent ID
            # print("comm_out_linear")
            # print(self._comm_out_linear.weight_lut.weight.data[0])
            # print(self._comm_out_linear.bias_lut.weight)
            comm_out = self._comm_out_linear(hidstate, self.model_ids)
        else:
            comm_out = hidstate
            if self.opts['comm_decoder'] >= 1:
                comm_out = self._comm_out_linear_alt(comm_out, self.model_ids) # hidsz -> hidsz
                if self.opts['comm_decoder'] == 2:
                    comm_out = self.__nonlin()(comm_out)
            comm_out = comm_out.repeat(self.nagents, 2) # repeat() is not in-place; hidsz -> 2 x hidsz # original: comm_out = nn.Contiguous()(nn.Replicate(self.nagents, 2)(comm_out))
        return comm_out

    def __action_comm(self, hidstate):
        action_comm = self._action_comm_linear(hidstate, self.model_ids)
        action_comm = nn.LogSoftmax()(action_comm)
        return action_comm


    def __nonlin(self):
        nonlin = self.opts['nonlin']
        if nonlin == 'tanh':
            return nn.Tanh()
        elif nonlin == 'relu':
            return nn.ReLU()
        elif nonlin == 'none':
            return Identity()
        else:
            raise Exception("wrong nonlin")

    def __build_encoder(self, hidsz):
        # in_dim = ((self.opts['visibility']*2+1) ** 2) * self.opts['nwords']
        in_dim = 1
        if self.opts['encoder_lut']:                   # if there are more than 1 agent, use a LookupTable
            return Encoder(in_dim, hidsz)
        else:                                          # if only 1 agent
            return nn.Linear(in_dim, hidsz)
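
One way to narrow this down (not part of the original code) is to hook every sub-module and report the first one whose output goes non-finite. A sketch of such a helper, written against a recent PyTorch API; with the Variable-era API above you would test (o.data != o.data).sum() > 0 instead of torch.isfinite:

import torch

def add_nan_hooks(model):
    # Register a forward hook on every sub-module of `model` (e.g. the CommNet
    # instance) and print the name of any module whose output contains NaN/inf.
    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for o in outs:
                if torch.is_tensor(o) and not torch.isfinite(o).all():
                    print('non-finite output from', name, type(module).__name__)
        return hook
    for name, sub in model.named_modules():
        sub.register_forward_hook(make_hook(name))

Calling add_nan_hooks(actor) right after constructing the model in train.py should then point at the sub-module (_rnn_enc, _rnn_linear, _action_linear, ...) where the NaN first appears.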

train.py:

# import logging as log
# # set logger
import numpy as np
from model import CommNet
from torch.autograd import Variable
from torch import nn
import torch

N_AGENTS = 5
BATCH_SIZE = 1
LEVER = 5
HIDSZ = 10


def train(episode):
    opts = {
        'comm_encoder': False,
        'nonlin': 'relu',
        'nactions_comm': 0,
        'nwords': 1,
        'encoder_lut_nil': None,
        'encoder_lut': True,
        'hidsz': HIDSZ,
        'nmodels': N_AGENTS,
        'nagents': N_AGENTS,
        'nactions': LEVER,
        'model': 'mlp',
        'batch_size': BATCH_SIZE,
        'fully_connected': True,
        'comm_decoder': 0,
    }

    actor = CommNet(opts).cuda()
    print(actor)


    inp = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, 1).type(torch.LongTensor), requires_grad=False) # input is none
    prev_hid = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ)
                             .type(torch.FloatTensor), requires_grad=False)
    prev_cell = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ), requires_grad=False)

    comm_in = Variable(
        torch.zeros(BATCH_SIZE * N_AGENTS,
                   N_AGENTS,
                   HIDSZ)
             .type(torch.FloatTensor), requires_grad=False)


    learning_rate = 1e-7
    optimizer = torch.optim.SGD(actor.parameters(), lr=learning_rate, momentum=0.9)
    loss_fn = torch.nn.MSELoss(size_average=False)

    # one hot for mapping action
    emb = nn.Embedding(5, 5).cuda()  # 5x5 so the identity weight assigned below fits
    emb.weight.data = torch.eye(5).cuda()

    # clip = 1e-1
    # torch.nn.utils.clip_grad_norm(actor.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)
    # # torch.nn.utils.clip_grad_norm(actor._action_comm_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm_out_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm2hid_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm_out_linear_alt.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._rnn_enc.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._rnn_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)
    ids = np.array([np.random.choice(N_AGENTS, LEVER, replace=False)
                    for _ in range(BATCH_SIZE)])
    # ids shape: [BATCH_SIZE, 5]
    model_ids = Variable(torch.from_numpy(np.reshape(ids, (1, -1))), requires_grad=False)

    for i in range(episode):
        print(i, '------------' * 5)
        # print([ w.data[0] for w in list(actor.parameters()) ])
        # print(actor.state_dict().keys())
        optimizer.zero_grad()

        action_prob, _baseline, prev_hid, comm_in = actor.forward(inp.cuda(),
                                                                  prev_hid.cuda(),
                                                                  prev_cell.cuda(),
                                                                  model_ids.cuda(),
                                                                  comm_in.cuda())

        comm_in = comm_in.view(BATCH_SIZE, N_AGENTS, N_AGENTS, HIDSZ)
        comm_in = comm_in.transpose(1, 2)
        comm_in = comm_in.contiguous().view(BATCH_SIZE * N_AGENTS, N_AGENTS, HIDSZ)

        # lever_output = torch.multinomial(action_prob, 1)
        # lever_ids = lever_output.view(BATCH_SIZE, LEVER)
        # one_hot = emb(lever_ids) # 1x5x5
        # distinct_sum = (one_hot.sum(1) > 0).sum(1).type(torch.FloatTensor)
        # reward = distinct_sum / LEVER
        # loss = - reward
        # print(reward.sum(0) / BATCH_SIZE)
        # repeat_reward = reward.view(1, BATCH_SIZE).data.repeat(1, LEVER).view(BATCH_SIZE * LEVER, 1)
        # lever_output.reinforce(repeat_reward.cuda())

        batch_actions = action_prob.sum(0)
        print("ACTION:")
        print(batch_actions)
        target = Variable(torch.ones(LEVER) * BATCH_SIZE, requires_grad=False).cuda()
        loss = loss_fn(batch_actions, target)

        loss.backward(retain_graph=True)
        optimizer.step()

if __name__ == "__main__":
    train(10000)
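
One more sketch, also not in the original post: since the parameters are only inspected once per iteration, the blow-up may be happening in the gradients between prints. A hypothetical helper like the one below, called just before optimizer.step(), would show whether gradient or weight norms explode in the iterations leading up to the NaN:

def report_norms(model, step):
    # Print the overall gradient and weight norms; large, rapidly growing
    # gradient norms are the usual precursor of NaN parameters/activations.
    grad_sq, weight_sq = 0.0, 0.0
    for p in model.parameters():
        weight_sq += float(p.data.norm()) ** 2
        if p.grad is not None:
            grad_sq += float(p.grad.data.norm()) ** 2
    print('step %d  |grad| = %.3e  |weight| = %.3e'
          % (step, grad_sq ** 0.5, weight_sq ** 0.5))

If the gradient norm grows by orders of magnitude per step, re-enabling the commented-out clip_grad_norm calls or lowering the learning rate would be the usual next things to try.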