Embeddings not getting updated

Thanks in advance. I have been staring at this code for two days now and cannot figure out why the model is not training: the embedding weights never seem to change. Can someone take a look?

model.py:

import torch
from linear_multi import LinearMulti
from torch import nn
from torch.legacy.nn import Identity  # only Identity is used below
from torch.autograd import Variable

class Encoder(nn.Module):
    def __init__(self, in_dim, hidsz):
        super(Encoder, self).__init__()
        self.lut = nn.Embedding(in_dim, hidsz) # in_dim agents, returns (batchsz, x, hidsz)
        self.bias = nn.Parameter(torch.randn(hidsz))

    def forward(self, inp):
        x = self.lut(inp)
        x = torch.sum(x, 1) # XXX: the original version is sum(2) but lua is 1-indexed
        x = x.add(self.bias) # XXX:
        return x

class CommNet(nn.Module):
    def __init__(self, opts):
        super(CommNet, self).__init__()
        self.opts = opts
        self.nmodels = opts['nmodels']
        self.nagents = opts['nagents']
        self.hidsz = opts['hidsz']
        self.nactions = opts['nactions']
        self.use_lstm = opts['model'] == 'lstm'

        # Comm
        if self.opts['comm_encoder']:
            # before merging comm and hidden, use a linear layer for comm
            if self.use_lstm: # LSTM has 4x weights for gates
                self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
            else:
                self._comm2hid_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # RNN: (comm + hidden) -> hidden
        if self.use_lstm:
            self._rnn_enc = self.__build_encoder(self.hidsz * 4)
            self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * 4)
        else:
            self._rnn_enc = self.__build_encoder(self.hidsz)
            self._rnn_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # Action layer
        self._action_linear = LinearMulti(self.nmodels, self.hidsz, self.nactions)
        self._action_baseline_linear = LinearMulti(self.nmodels, self.hidsz, 1)

        # Comm_out
        self._comm_out_linear = LinearMulti(self.nmodels, self.hidsz, self.hidsz * self.nagents)
        self._comm_out_linear_alt = LinearMulti(self.nmodels, self.hidsz, self.hidsz)

        # action_comm
        nactions_comm = self.opts['nactions_comm']
        if nactions_comm > 1:
            self._action_comm_linear = LinearMulti(self.nmodels, self.hidsz, nactions_comm)

    def forward(self, inp, prev_hid, prev_cell, model_ids, comm_in):
        self.model_ids = model_ids
        comm2hid = self.__comm2hid(comm_in)
        # below are the return values, for next time step
        if self.use_lstm:
            hidstate, prev_cell = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)
        else:
            hidstate = self.__hidstate(inp, prev_hid, prev_cell, comm2hid)

        action_prob, baseline = self.__action(hidstate)

        comm_out = self.__comm_out(hidstate)

        if self.opts['nactions_comm'] > 1:
            action_comm = self.__action_comm(hidstate)
            return (action_prob, baseline, hidstate, comm_out, action_comm)
        else:
            return (action_prob, baseline, hidstate, comm_out)

    def __comm2hid(self, comm_in):
        # Lua Sum(2) -> Python sum(1), shape: [batch x nagents, hidden]
        comm2hid = torch.sum(comm_in, 1) # XXX: sum(2) -> 0-index
        if self.opts['comm_encoder']:
            comm2hid = self._comm2hid_linear(comm2hid, self.model_ids)
        return comm2hid

    def __hidstate(self, inp, prev_hid, prev_cell, comm2hid):
        if self.opts['model'] == 'mlp' or self.opts['model'] == 'rnn':
            hidstate = self._rnn(inp, prev_hid, comm2hid)
        elif self.use_lstm:
            hidstate, cellstate = self._lstm(inp, prev_hid, prev_cell, comm2hid)
            return hidstate, cellstate
        else:
            raise Exception('model not supported')
        return hidstate

    def _lstm(self, inp, prev_hid, prev_cell, comm_in):
        pre_hid = []
        pre_hid.append(self._rnn_enc(inp))
        pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
        # if comm_in:
        pre_hid.append(comm_in)
        A = sum(pre_hid)
        B = A.view(-1, 4, self.hidsz)
        C = torch.split(B, self.hidsz, 0)

        gate_forget = nn.Sigmoid()(C[0][0])
        gate_write = nn.Sigmoid()(C[0][1])
        gate_read = nn.Sigmoid()(C[0][2])
        in2c = self.__nonlin()(C[0][3])
        print(gate_forget.size(), prev_cell.size())
        print(in2c.size(), gate_write.transpose(0, 1).size())
        cellstate = sum([
            torch.matmul(gate_forget, prev_cell),
            torch.matmul(in2c.transpose(0,1), gate_write)
        ])
        hidstate = torch.matmul(self.__nonlin()(cellstate), gate_read)
        return hidstate, cellstate

    def _rnn(self, inp, prev_hid, comm_in):
        pre_hid = []
        pre_hid.append(self._rnn_enc(inp))

        pre_hid.append(self._rnn_linear(prev_hid, self.model_ids))
        # if comm_in:
        pre_hid.append(comm_in)

        sum_pre_hid = sum(pre_hid)
        hidstate = self.__nonlin()(sum_pre_hid)
        return hidstate

    def __action(self, hidstate):
        print('action_linear')
        print(self._action_linear.weight_lut.weight)
        action = self._action_linear(hidstate, self.model_ids)
        action_prob = nn.Softmax()(action) # was LogSoftmax

        baseline =  self._action_baseline_linear(hidstate, self.model_ids)

        return action_prob, baseline

    def __comm_out(self, hidstate):
        if self.opts['fully_connected']:
            # use different params depending on agent ID
            comm_out = self._comm_out_linear(hidstate, self.model_ids)
        else:
            # this is kind of weird, need to consult paper
            # linear from hidsz to hidsz, then non linear, then repeat?
            comm_out = hidstate
            if self.opts['comm_decoder'] >= 1:
                comm_out = self._comm_out_linear_alt(comm_out, self.model_ids) # hidsz -> hidsz
                if self.opts['comm_decoder'] == 2:
                    comm_out = self.__nonlin()(comm_out)
            comm_out = comm_out.repeat(self.nagents, 2) # repeat() is not in-place, so keep the result # original: comm_out = nn.Contiguous()(nn.Replicate(self.nagents, 2)(comm_out))
        return comm_out

    def __action_comm(self, hidstate):
        action_comm = self._action_comm_linear(hidstate, self.model_ids)
        action_comm = nn.LogSoftmax()(action_comm)
        return action_comm


    def __nonlin(self):
        nonlin = self.opts['nonlin']
        if nonlin == 'tanh':
            return nn.Tanh()
        elif nonlin == 'relu':
            return nn.ReLU()
        elif nonlin == 'none':
            return Identity()
        else:
            raise Exception("wrong nonlin")

    def __build_encoder(self, hidsz):
        # in_dim = ((self.opts['visibility']*2+1) ** 2) * self.opts['nwords']
        in_dim = 1
        if self.opts['encoder_lut']:                   # if there are more than 1 agent, use a LookupTable
            return Encoder(in_dim, hidsz)
        else:                                          # if only 1 agent
            return nn.Linear(in_dim, hidsz)
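
linear_multi.py is not included above; roughly, LinearMulti keeps one weight matrix and one bias per model id in embedding tables (hence the weight_lut attribute printed in __action) and selects them with model_ids at forward time. A stripped-down sketch of that idea, not the exact file:

import torch
from torch import nn

class LinearMulti(nn.Module):
    # Sketch only: per-model affine transform, parameters selected by model id.
    def __init__(self, nmodels, in_dim, out_dim):
        super(LinearMulti, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.weight_lut = nn.Embedding(nmodels, in_dim * out_dim)
        self.bias_lut = nn.Embedding(nmodels, out_dim)

    def forward(self, inp, model_ids):
        # inp: (batch, in_dim); model_ids: LongTensor with one model id per row of inp
        ids = model_ids.view(-1)
        weight = self.weight_lut(ids).view(-1, self.in_dim, self.out_dim)
        bias = self.bias_lut(ids).view(-1, 1, self.out_dim)
        out = torch.bmm(inp.view(-1, 1, self.in_dim), weight) + bias
        return out.view(-1, self.out_dim)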

train.py:

# import logging as log
# # set logger
# log.basicConfig(level=log.INFO, filename="leaver_train.log")
# console = log.StreamHandler()
# console.setLevel(log.INFO)
# log.getLogger("").addHandler(console)
import numpy as np
from model import CommNet
from torch.autograd import Variable
from torch import nn
import torch

N_AGENTS = 3
BATCH_SIZE = 1
LEVER = 3 
HIDSZ = 3


def train(episode):
    opts = {
        'comm_encoder': True,
        'nonlin': 'tanh',
        'nactions_comm': 0,
        'nwords': 1,
        'encoder_lut_nil': None,
        'encoder_lut': True,
        'hidsz': HIDSZ,
        'nmodels': N_AGENTS * 2,
        'nagents': N_AGENTS,
        'nactions': LEVER,
        'model': 'mlp',
        'batch_size': BATCH_SIZE,
        'fully_connected': True,
        'comm_decoder': 0,
    }

    actor = CommNet(opts).cuda()
    print(actor)


    inp = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, 1).type(torch.LongTensor)) # input is none
    prev_hid = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ)
                             .type(torch.FloatTensor))
    prev_cell = Variable(torch.zeros(BATCH_SIZE * N_AGENTS, HIDSZ))

    comm_in = Variable(
        torch.zeros(BATCH_SIZE * N_AGENTS,
                   N_AGENTS,
                   HIDSZ)
             .type(torch.FloatTensor))


    learning_rate = 1e-7
    optimizer = torch.optim.Adagrad(actor.parameters(), lr=learning_rate)
    loss_fn = torch.nn.MSELoss(size_average=False)

    # identity embedding used as a one-hot lookup: lever index -> one-hot vector
    emb = nn.Embedding(5, 5).cuda()
    emb.weight.data = torch.eye(5).cuda()

    # clip = 1e-1
    # torch.nn.utils.clip_grad_norm(actor.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)
    # # torch.nn.utils.clip_grad_norm(actor._action_comm_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm_out_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm2hid_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._comm_out_linear_alt.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._rnn_enc.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._rnn_linear.parameters(), clip)
    # torch.nn.utils.clip_grad_norm(actor._action_baseline_linear.parameters(), clip)
    for i in range(episode):
        print(i)
        optimizer.zero_grad()
        ids = np.array([np.random.choice(N_AGENTS, LEVER, replace=False)
                        for _ in range(BATCH_SIZE)])
        # ids shape: [BATCH_SIZE, LEVER]
        model_ids = Variable(torch.from_numpy(np.reshape(ids, (1, -1))))


        action_prob, _baseline, prev_hid, comm_in = actor.forward(inp.cuda(),
                                                                 prev_hid.cuda(),
                                                                 prev_cell.cuda(),
                                                                 model_ids.cuda(),
                                                                 comm_in.cuda())

        comm_in = comm_in.view(BATCH_SIZE, N_AGENTS, N_AGENTS, HIDSZ)
        comm_in = comm_in.transpose(1, 2)
        comm_in = comm_in.contiguous().view(BATCH_SIZE * N_AGENTS, N_AGENTS, HIDSZ)
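        # the view/transpose/view above swaps the sender and receiver axes of the
        # (batch, agent, agent, hidsz) message tensor, so each agent receives the
        # messages the other agents emitted for it on the next forward pass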

        lever_output = torch.multinomial(action_prob, 1)
        lever_ids = lever_output.view(BATCH_SIZE, LEVER)
        print(lever_ids)
        one_hot = emb(lever_ids) # shape: (BATCH_SIZE, LEVER, 5)
        distinct_sum = (one_hot.sum(1) > 0).sum(1).type(torch.FloatTensor)
        reward = distinct_sum / LEVER
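        # e.g. with LEVER = 3: sampled levers [0, 2, 2] pull only 2 distinct levers,
        # so reward = 2 / 3; pulling all distinct levers gives reward 1.0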

        loss = - reward

        # batch_actions = action_prob.sum(0)
        # target = torch.ones(5) * BATCH_SIZE
        # loss = loss_fn(batch_actions, Variable(target, requires_grad=False))
        print(reward.sum(0) / BATCH_SIZE)
        repeat_reward = reward.view(1, BATCH_SIZE).data.repeat(1, LEVER).view(BATCH_SIZE * LEVER, 1)
        lever_output.reinforce(repeat_reward.cuda())
        loss.backward(retain_graph=True)
        optimizer.step()
        


        # reward = env.step(action_prob)

        # actor.train(ids, base_line=baseline, base_reward=reward, itr=i, log=log)
        # critic.train(ids, base_reward=reward, itr=i, log=log)


if __name__ == "__main__":
    train(10000)
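
For reference, this is how I have been checking whether the parameters (including the Encoder's lut embedding) actually change between optimizer steps. snapshot_params and report_updates are just throwaway helpers, not part of the scripts above:

def snapshot_params(module):
    # copy every parameter so we can diff it after optimizer.step()
    return {name: p.data.clone() for name, p in module.named_parameters()}

def report_updates(module, before):
    # max absolute change per parameter; 0.0 means it was not updated
    for name, p in module.named_parameters():
        delta = (p.data - before[name]).abs().max()
        print(name, delta, '(no grad)' if p.grad is None else '')

Calling snapshot_params(actor) just before optimizer.step() and report_updates(actor, before) just after shows zero change for the embedding weights, which is what the title refers to.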