None of the weights or biases is NaN, yet the output of softmax/LogSoftmax is NaN. Below is the parameter dump printed at step 56, followed by the action output:
(56, '------------------------------------------------------------')
[0.06725167483091354,
-0.1485
-0.7528
-0.0373
1.3009
1.4958
0.1263
-0.1214
-1.0434
-0.0700
-0.3236
[torch.cuda.FloatTensor of size 10 (GPU 0)]
,
... (the remaining parameter tensors, of sizes 100, 10, 50, 5, 10, 1, 500 and 50, are elided here for brevity; every value in the full dump is an ordinary finite float, with no nan or inf anywhere) ...
]
ACTION:
Variable containing:
nan
nan
nan
nan
nan
[torch.cuda.FloatTensor of size 5 (GPU 0)]
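(To narrow down where the nan first appears, here is a minimal sketch, not part of the original script, that registers a forward hook on every submodule and reports the first one whose output stops being finite. check_finite is a made-up helper name, and the usage lines assume the actor = CommNet(opts).cuda() object from the training script further down, plus an nn.Module.named_modules method in your torch version.)

import numpy as np

def check_finite(name):
    def hook(module, inputs, output):
        # forward hooks may receive a tuple of outputs; look at the first piece
        out = output[0] if isinstance(output, tuple) else output
        data = out.data if hasattr(out, 'data') else out
        if not np.isfinite(data.cpu().numpy()).all():
            print(name, 'produced a non-finite (nan/inf) output')
    return hook

# usage sketch, assuming `actor = CommNet(opts).cuda()` as in the training script below:
# for name, module in actor.named_modules():
#     module.register_forward_hook(check_finite(name))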
model:
import torch
from linear_multi import LinearMulti
from torch import nn
from torch.legacy.nn import Add, Sum, Identity
from torch.autograd import Variable


class Encoder(nn.Module):
    def __init__(self, in_dim, hidsz):
        super(Encoder, self).__init__()
        self.lut = nn.Embedding(in_dim, hidsz)  # in_dim agents, returns (batchsz, x, hidsz)
        self._bias = nn.Parameter(torch.randn(hidsz))

    def forward(self, inp):
        x = self.lut(inp)
        x = torch.sum(x, 1)  # XXX: the original version is sum(2) but lua is 1-indexed
        x = x.add(self._bias)  # XXX:
        return x


class CommNet(nn.Module):
    def __init__(self, opts):
        super(CommNet, self).__init__()
        self.opts = opts
        self.nmodels = opts['nmodels']
        self.nagents = opts['nagents']
        self.hidsz = opts['hidsz']
        self.nactions = opts['nactions']
        self.use_lstm = opts['model'] == 'lstm'

        # Comm
        if self.opts['comm_encoder']:
            # before merging comm and hidden, use a linear layer for comm
            if self.use_lstm:  # LSTM has 4x weights for gates
                self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
            else:
                self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # RNN: (comm + hidden) -> hidden
        if self.use_lstm:
            self._rnn_enc = self.__build_encoder(self.hidsz * 4)
            self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
        else:
            self._rnn_enc = self.__build_encoder(self.hidsz)
            self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # Action layer
        self._action_linear = LinearMulti(self.nmodels, self.hidsz, self.nactions)
        self._action_baseline_linear = LinearMulti(self.nmodels, self.hidsz, 1)

        # Comm_out
        self._comm_out_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * self.nagents)
        if self.opts['comm_decoder'] >= 1:
            self._comm_out_linear_alt = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # action_comm
        nactions_comm = self.opts['nactions_comm']
        if nactions_comm > 1:
            self._action_comm_linear = LinearMulti(self.nmodels, self.hidsz, nactions_comm)

    def forward(self, inp, prev_hid, prev_cell, model_ids, comm_in):
        self.model_ids = model_ids
        comm2hid = self.__comm2hid(comm_in)

        # below are the return values, for next time step
        if self.use_lstm:
            hidstate, prev_cell = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)
        else:
            hidstate = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)

        action_prob, baseline = self.__action(hidstate)
        comm_out = self.__comm_out(hidstate)

        if self.opts['nactions_comm'] > 1:
            action_comm = self.__action_comm(hidstate)
            return (action_prob, baseline, hidstate, comm_out, action_comm)
        else:
            return (action_prob, baseline, hidstate, comm_out)

    def __comm2hid(self, comm_in):
        # Lua Sum(2) -> Python sum(1), shape: [batch x nagents, hidden]
        comm2hid = torch.sum(comm_in, 1)  # XXX: sum(2) -> 0-index
        if self.opts['comm_encoder']:
            comm2hid = self._comm2hid_linear(comm2hid, self.model_ids)
        return comm2hid

    def __hidstate(self, inp, prev_hid, prev_cell, comm2hid):
        if self.opts['model'] == 'mlp' or self.opts['model'] == 'rnn':
            hidstate = self._rnn(inp, prev_hid, comm2hid)
        elif self.use_lstm:
            hidstate, cellstate = self._lstm(inp, prev_hid, prev_cell, comm2hid)
            return hidstate, cellstate
        else:
            raise Exception('model not supported')
        return hidstate

    def _lstm(self, inp, prev_hid, prev_cell, comm_in):
        pre_hid = []
        pre_hid.append(self._rnn_enc(inp))
        pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
        # if comm_in:
        pre_hid.append(comm_in)

        A = sum(pre_hid)
        B = A.view(-1, 4, self.hidsz)
        C = torch.split(B, self.hidsz, 0)

        gate_forget = nn.Sigmoid()(C[0][0])
        gate_write = nn.Sigmoid()(C[0][1])
        gate_read = nn.Sigmoid()(C[0][2])
        in2c = self.__nonlin()(C[0][3])
        # print gate_forget.size(), prev_cell.size()
        # print in2c.size(), gate_write.transpose(0,1).size()
        cellstate = sum([
            torch.matmul(gate_forget, prev_cell),
            torch.matmul(in2c.transpose(0, 1), gate_write)
        ])
        hidstate = torch.matmul(self.__nonlin()(cellstate), gate_read)
        return hidstate, cellstate

    def _rnn(self, inp, prev_hid, comm_in):
        pre_hid = []
        pre_hid.append(self._rnn_enc(inp))
        # print("_rnn_linear_weight")
        # print(self._rnn_linear.weight_lut.weight)
        # print(self._rnn_linear.bias_lut.weight)
        pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
        # if comm_in:
        pre_hid.append(comm_in)

        sum_pre_hid = sum(pre_hid)
        hidstate = self.__nonlin()(sum_pre_hid)
        return hidstate

    def __action(self, hidstate):
        # print('action_linear')
        # print(self._action_linear.weight_lut.weight)
        # print(self._action_linear.bias_lut.weight)
        action = self._action_linear(hidstate, self.model_ids)
        action_prob = nn.LogSoftmax()(action)  # was LogSoftmax

        # print('action_baseline_linear')
        # print(self._action_baseline_linear.weight_lut.weight)
        # print(self._action_baseline_linear.bias_lut.weight)
        baseline = self._action_baseline_linear(hidstate, self.model_ids)
        return action_prob, baseline

    def __comm_out(self, hidstate):
        if self.opts['fully_connected']:
            # use different params depending on agent ID
            # print("comm_out_linear")
            # print(self._comm_out_linear.weight_lut.weight.data[0])
            # print(self._comm_out_linear.bias_lut.weight)
            comm_out = self._comm_out_linear(hidstate, self.model_ids)
        else:
            comm_out = hidstate
            if self.opts['comm_decoder'] >= 1:
                comm_out = self._comm_out_linear_alt(comm_out, self.model_ids)  # hidsz -> hidsz
                if self.opts['comm_decoder'] == 2:
                    comm_out = self.__nonlin()(comm_out)
            comm_out.repeat(self.nagents, 2)  # hidsz -> 2 x hidsz  # original: comm_out = nn.Contiguous()(nn.Replicate(self.nagents, 2)(comm_out))
        return comm_out

    def __action_comm(self, hidstate):
        action_comm = self._action_comm_linear(hidstate, self.model_ids)
        action_comm = nn.LogSoftmax()(action_comm)
        return action_comm

    def __nonlin(self):
        nonlin = self.opts['nonlin']
        if nonlin == 'tanh':
            return nn.Tanh()
        elif nonlin == 'relu':
            return nn.ReLU()
        elif nonlin == 'none':
            return Identity()
        else:
            raise Exception("wrong nonlin")

    def __build_encoder(self, hidsz):
        # in_dim = ((self.opts['visibility']*2+1) ** 2) * self.opts['nwords']
        in_dim = 1
        if self.opts['encoder_lut']:  # if there are more than 1 agent, use a LookupTable
            return Encoder(in_dim, hidsz)
        else:  # if only 1 agent
            return nn.Linear(in_dim, hidsz)
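One detail that matters when reading __action above: nn.LogSoftmax only returns finite values when its input is finite, so finite weights in _action_linear are not sufficient on their own. A tiny stand-alone check (same old Variable-style API as the code above, nothing taken from the original model):

import torch
from torch import nn
from torch.autograd import Variable

x = Variable(torch.FloatTensor([1.0, float('inf'), 2.0]))
print(nn.LogSoftmax()(x))  # contains nan even though no weight is involved

So if hidstate ever reaches inf before it hits _action_linear, the nan will surface in action_prob, which would be consistent with the output above.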
# import logging as log
# # set logger
import numpy as np
from model import CommNet
from torch.autograd import Variable
from torch import nn
import torch

N_AGENTS = 5
BATCH_SIZE = 1
LEVER = 5
HIDSZ = 10


def train(episode):
    opts = {
        'comm_encoder': False,
        'nonlin': 'relu',
        'nactions_comm': 0,
        'nwords': 1,
        'encoder_lut_nil': None,
        'encoder_lut': True,
        'hidsz': HIDSZ,
        'nmodels': N_AGENTS,
        'nagents': N_AGENTS,
        'nactions': LEVER,
        'model': 'mlp',
        'batch_size': BATCH_SIZE,
        'fully_connected': True,
        'comm_decoder': 0,
    }

    actor = CommNet(opts).cuda()
    print(actor)

    inp = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, 1).type(torch.LongTensor),
                   requires_grad=False)  # input is none
    prev_hid = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ).type(torch.FloatTensor),
                        requires_grad=False)
    prev_cell = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ), requires_grad=False)
    comm_in = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, N_AGENTS, HIDSZ).type(torch.FloatTensor),
                       requires_grad=False)

    learning_rate = 1e-7
    optimizer = torch.optim.SGD(actor.parameters(), lr=learning_rate, momentum=0.9)
    loss_fn = torch.nn.MSELoss(size_average=False)

    # one hot for mapping action
    emb = nn.Embedding(1, 5).cuda()
    emb.weight.data = torch.eye(5).cuda()

    # clip = 1e-1
    # torch.nn.utils.clip_grad_norm(actor.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)
    # # torch.nn.utils.clip_grad_norm(actor._action_comm_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm_out_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm2hid_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm_out_linear_alt.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._rnn_enc.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._rnn_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)

    ids = np.array([np.random.choice(N_AGENTS, LEVER, replace=False)
                    for _ in range(BATCH_SIZE)])
    # ids shape: [BATCH_SIZE, 5]
    model_ids = Variable(torch.from_numpy(np.reshape(ids, (1, -1))), requires_grad=False)

    for i in range(episode):
        print(i, '------------' * 5)
        # print([w.data[0] for w in list(actor.parameters())])
        # debugging leftovers; if left active, the break would exit the loop before the forward pass
        # print(actor.state_dict().keys())
        # break

        optimizer.zero_grad()

        action_prob, _baseline, prev_hid, comm_in = actor.forward(inp.cuda(),
                                                                  prev_hid.cuda(),
                                                                  prev_cell.cuda(),
                                                                  model_ids.cuda(),
                                                                  comm_in.cuda())

        comm_in = comm_in.view(BATCH_SIZE, N_AGENTS, N_AGENTS, HIDSZ)
        comm_in = comm_in.transpose(1, 2)
        comm_in = comm_in.contiguous().view(BATCH_SIZE * N_AGENTS, N_AGENTS, HIDSZ)

        # lever_output = torch.multinomial(action_prob, 1)
        # lever_ids = lever_output.view(BATCH_SIZE, LEVER)
        # one_hot = emb(lever_ids)  # 1x5x5
        # distinct_sum = (one_hot.sum(1) > 0).sum(1).type(torch.FloatTensor)
        # reward = distinct_sum / LEVER
        # loss = - reward
        # print(reward.sum(0) / BATCH_SIZE)
        # repeat_reward = reward.view(1, BATCH_SIZE).data.repeat(1, LEVER).view(BATCH_SIZE * LEVER, 1)
        # lever_output.reinforce(repeat_reward.cuda())

        batch_actions = action_prob.sum(0)
        print("ACTION:")
        print(batch_actions)

        target = Variable(torch.ones(LEVER) * BATCH_SIZE, requires_grad=False).cuda()
        loss = loss_fn(batch_actions, target)
        loss.backward(retain_graph=True)
        optimizer.step()


if __name__ == "__main__":
    train(10000)
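A hedged guess at what to inspect, based only on the loop above (an assumption, not a confirmed diagnosis): prev_hid is reassigned to the returned hidstate every iteration and the graph is kept alive with retain_graph=True, so the recurrent state is never detached, and with the 'relu' nonlinearity nothing bounds its growth across steps. Printing its magnitude and detaching the carried state each step makes it easy to see whether the activations reach inf before LogSoftmax starts returning nan:

    # inside the training loop, right after the forward pass (sketch; names as above)
    print(prev_hid.data.abs().max())                          # watch for runaway growth
    prev_hid = Variable(prev_hid.data, requires_grad=False)   # detach the recurrent state
    comm_in = Variable(comm_in.data, requires_grad=False)     # likewise for the comm tensor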