Thank you very much. I have been staring at this code for 2 days now. Can you check why this model isn’t training:
import torch
from linear_multi import LinearMulti
from torch import nn
from torch.legacy.nn import Add, Sum, Identity
from torch.autograd import Variable
class Encoder(nn.Module):
def __init__(self, in_dim, hidsz):
super(Encoder, self).__init__()
self.lut = nn.Embedding(in_dim, hidsz) # in_dim agents, returns (batchsz, x, hidsz)
self.bias = nn.Parameter(torch.randn(hidsz))
def forward(self, inp):
x = self.lut(inp)
x = torch.sum(x, 1) # XXX: the original version is sum(2) but lua is 1-indexed
x = x.add(self.bias) # XXX:
return x
class CommNet(nn.Module):
def __init__(self, opts):
super(CommNet, self).__init__()
self.opts = opts
self.nmodels = opts['nmodels']
self.nagents = opts['nagents']
self.hidsz = opts['hidsz']
self.nactions = opts['nactions']
self.use_lstm = opts['model'] == 'lstm'
# Comm
if self.opts['comm_encoder']:
# before merging comm and hidden, use a linear layer for comm
if self.use_lstm: # LSTM has 4x weights for gates
self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
else:
self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)
# RNN: (comm + hidden) -> hidden
if self.use_lstm:
self._rnn_enc = self.__build_encoder(self.hidsz * 4)
self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
else:
self._rnn_enc = self.__build_encoder(self.hidsz)
self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)
# Action layer
self._action_linear = LinearMulti(self.nmodels, self.hidsz, self.nactions)
self._action_baseline_linear = LinearMulti(self.nmodels, self.hidsz, 1)
# Comm_out
self._comm_out_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * self.nagents)
self._comm_out_linear_alt = LinearMulti(self.nmodels, self.hidsz, self.hidsz)
# action_comm
nactions_comm = self.opts['nactions_comm']
if nactions_comm > 1:
self._action_comm_linear = LinearMulti(self.nmodels, self.hidsz, nactions_comm)
def forward(self, inp, prev_hid, prev_cell, model_ids, comm_in):
self.model_ids = model_ids
comm2hid = self.__comm2hid(comm_in)
# below are the return values, for next time step
if self.use_lstm:
hidstate, prev_cell = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)
else:
hidstate = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)
action_prob, baseline = self.__action(hidstate)
comm_out = self.__comm_out(hidstate)
if self.opts['nactions_comm'] > 1:
action_comm = self.__action_comm(hidstate)
return (action_prob, baseline, hidstate, comm_out, action_comm)
else:
return (action_prob, baseline, hidstate, comm_out)
def __comm2hid(self, comm_in):
# Lua Sum(2) -> Python sum(1), shape: [batch x nagents, hidden]
comm2hid = torch.sum(comm_in, 1) # XXX: sum(2) -> 0-index
if self.opts['comm_encoder']:
comm2hid = self._comm2hid_linear(comm2hid, self.model_ids)
return comm2hid
def __hidstate(self, inp, prev_hid, prev_cell, comm2hid):
if self.opts['model'] == 'mlp' or self.opts['model'] == 'rnn':
hidstate = self._rnn(inp, prev_hid, comm2hid)
elif self.use_lstm:
hidstate, cellstate = self._lstm(inp, prev_hid, prev_cell, comm2hid)
return hidstate, cellstate
else:
raise Exception('model not supported')
return hidstate
def _lstm(self, inp, prev_hid, prev_cell, comm_in):
pre_hid = []
pre_hid.append(self._rnn_enc(inp))
pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
# if comm_in:
pre_hid.append(comm_in)
A = sum(pre_hid)
B = A.view(-1, 4, self.hidsz)
C = torch.split(B, self.hidsz, 0)
gate_forget = nn.Sigmoid()(C[0][0])
gate_write = nn.Sigmoid()(C[0][1])
gate_read = nn.Sigmoid()(C[0][2])
in2c = self.__nonlin()(C[0][3])
print gate_forget.size(), prev_cell.size()
print in2c.size(), gate_write.transpose(0,1).size()
cellstate = sum([
torch.matmul(gate_forget, prev_cell),
torch.matmul(in2c.transpose(0,1), gate_write)
])
hidstate = torch.matmul(self.__nonlin()(cellstate), gate_read)
return hidstate, cellstate
def _rnn(self, inp, prev_hid, comm_in):
pre_hid = []
pre_hid.append(self._rnn_enc(inp))
pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
# if comm_in:
pre_hid.append(comm_in)
sum_pre_hid = sum(pre_hid)
hidstate = self.__nonlin()(sum_pre_hid)
return hidstate
def __action(self, hidstate):
print 'action_linear'
print self._action_linear.weight_lut.weight
action = self._action_linear(hidstate, self.model_ids)
action_prob = nn.Softmax()(action) # was LogSoftmax
baseline = self._action_baseline_linear(hidstate, self.model_ids)
return action_prob, baseline
def __comm_out(self, hidstate):
if self.opts['fully_connected']:
# use different params depending on agent ID
comm_out = self._comm_out_linear(hidstate, self.model_ids)
else:
# this is kind of weird, need to consult paper
# linear from hidsz to hidsz, then non linear, then repeat?
comm_out = hidstate
if self.opts['comm_decoder'] >= 1:
comm_out = self._comm_out_linear_alt(comm_out, self.model_ids) # hidsz -> hidsz
if self.opts['comm_decoder'] == 2:
comm_out = self.__nonlin()(comm_out)
comm_out.repeat(self.nagents, 2) # hidsz -> 2 x hidsz # original: comm_out = nn.Contiguous()(nn.Replicate(self.nagents, 2)(comm_out))
return comm_out
def __action_comm(self, hidstate):
action_comm = self._action_comm_linear(hidstate, self.model_ids)
action_comm = nn.LogSoftmax()(action_comm)
return action_comm
def __nonlin(self):
nonlin = self.opts['nonlin']
if nonlin == 'tanh':
return nn.Tanh()
elif nonlin == 'relu':
return nn.ReLU()
elif nonlin == 'none':
return Identity()
else:
raise Exception("wrong nonlin")
def __build_encoder(self, hidsz):
# in_dim = ((self.opts['visibility']*2+1) ** 2) * self.opts['nwords']
in_dim = 1
if self.opts['encoder_lut']: # if there are more than 1 agent, use a LookupTable
return Encoder(in_dim, hidsz)
else: # if only 1 agent
return nn.Linear(in_dim, hidsz)
# import logging as log
# # set logger
# log.basicConfig(level=log.INFO, filename="leaver_train.log")
# console = log.StreamHandler()
# console.setLevel(log.INFO)
# log.getLogger("").addHandler(console)
import numpy as np
from model import CommNet
from torch.autograd import Variable
from torch import nn
import torch
N_AGENTS = 3
BATCH_SIZE = 1
LEVER = 3
HIDSZ = 3
def train(episode):
opts = {
'comm_encoder': True,
'nonlin': 'tanh',
'nactions_comm': 0,
'nwords': 1,
'encoder_lut_nil': None,
'encoder_lut': True,
'hidsz': HIDSZ,
'nmodels': N_AGENTS * 2,
'nagents': N_AGENTS,
'nactions': LEVER,
'model': 'mlp',
'batch_size': BATCH_SIZE,
'fully_connected': True,
'comm_decoder': 0,
}
actor = CommNet(opts).cuda()
print(actor)
inp = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, 1).type(torch.LongTensor)) # input is none
prev_hid = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ)
.type(torch.FloatTensor))
prev_cell = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ))
comm_in = Variable(
torch.zeros(BATCH_SIZE * N_AGENTS,
N_AGENTS,
HIDSZ)
.type(torch.FloatTensor))
learning_rate = 1e-7
optimizer = torch.optim.Adagrad(actor.parameters(), lr=learning_rate)
loss_fn = torch.nn.MSELoss(size_average=False)
# one hot for mapping action
emb = nn.Embedding(1, 5).cuda()
emb.weight.data = torch.eye(5).cuda()
# clip = 1e-1
# torch.nn.utils.clip_grad_norm(actor.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)
# # torch.nn.utils.clip_grad_norm(actor._action_comm_linear.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._action_linear.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._comm_out_linear.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._comm2hid_linear.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._comm_out_linear_alt.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._rnn_enc.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._rnn_linear.parameters(), clip)
# torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)
for i in range(episode):
print i
optimizer.zero_grad()
ids = np.array([np.random.choice(N_AGENTS, LEVER, replace=False)
for _ in range(BATCH_SIZE)])
# ids shape: [BATCH_SIZE, 5]
model_ids = Variable(torch.from_numpy(np.reshape(ids, (1, -1))))
action_prob, _baseline, prev_hid, comm_in = actor.forward(inp.cuda(),
prev_hid.cuda(),
prev_cell.cuda(),
model_ids.cuda(),
comm_in.cuda())
comm_in = comm_in.view(BATCH_SIZE, N_AGENTS, N_AGENTS, HIDSZ)
comm_in = comm_in.transpose(1, 2)
comm_in = comm_in.contiguous().view(BATCH_SIZE * N_AGENTS, N_AGENTS, HIDSZ)
lever_output = torch.multinomial(action_prob, 1)
lever_ids = lever_output.view(BATCH_SIZE, LEVER)
print lever_ids
one_hot = emb(lever_ids) # 1x5x5
distinct_sum = (one_hot.sum(1) > 0).sum(1).type(torch.FloatTensor)
reward = distinct_sum / LEVER
loss = - reward
# batch_actions = action_prob.sum(0)
# target = torch.ones(5) * BATCH_SIZE
# loss = loss_fn(batch_actions, Variable(target, requires_grad=False))
print(reward.sum(0) / BATCH_SIZE)
repeat_reward = reward.view(1, BATCH_SIZE).data.repeat(1, LEVER).view(BATCH_SIZE * LEVER, 1)
lever_output.reinforce(repeat_reward.cuda())
loss.backward(retain_graph=True)
optimizer.step()
# reward = env.step(action_prob)
# actor.train(ids, base_line=baseline, base_reward=reward, itr=i, log=log)
# critic.train(ids, base_reward=reward, itr=i, log=log)
if __name__ == "__main__":
train(10000)