Wrapper nn.Module class and non-standard forward() method

Hi. I’m new to PyTorch. I’m working on a project that uses the bAbI dataset. My code is very messy, so I want to show as little of it as I can. I have several modules, and one of them is a wrapper around the others. Besides `forward()`, the wrapper module has several helper methods, which are called from the wrapper’s `forward()`. Do I have to worry about this setup? Will my code train properly? I am actually trying to fix a problem where my model stops improving once it reaches about 50% accuracy. Could this somehow be related to my wrapper’s `forward()` method?

class WrapMemRNN(nn.Module):
    """Wrapper module that chains the encoder, memory, attention and answer sub-modules.

    From the visible code: two ``Encoder`` instances encode the input story and the
    question, an attention/episode loop runs ``hparams['babi_memory_hops']`` times to
    refine a memory vector, and ``AnswerModule`` maps the final memory plus the
    question encoding to output scores. The layout resembles a Dynamic Memory
    Network (DMN) -- NOTE(review): confirm against the sub-module definitions.

    The helper methods cache intermediate results on ``self`` (``inp_c_seq``,
    ``q_q``, ``last_mem``, ...) and later helpers read them within the same
    ``forward`` call. This is fine for autograd: every operation executed while
    ``forward`` runs is recorded, whichever method executes it.

    Reads the module-level ``hparams`` dict for ``'teacher_forcing_ratio'`` and
    ``'babi_memory_hops'``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, n_layers, dropout=0.3,
                 do_babi=True, bad_token_lst=None, freeze_embedding=False, embedding=None):
        """Build the sub-modules.

        Args:
            vocab_size: passed to both encoders and to the answer module.
            embed_dim: passed to both encoders.
            hidden_size: hidden size shared by every sub-module.
            n_layers: number of encoder layers.
            dropout: dropout passed to every sub-module.
            do_babi: stored on the instance; not read inside this class.
            bad_token_lst: stored on the instance; not read inside this class.
                ``None`` (the default) gives each instance its own empty list.
            freeze_embedding: if True, the encoders' embedding weights are frozen.
            embedding: optional embedding handed to both encoders.
        """
        super(WrapMemRNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.do_babi = do_babi
        # A `[]` default would be one list object shared by every instance.
        self.bad_token_lst = [] if bad_token_lst is None else bad_token_lst
        self.embedding = embedding
        self.freeze_embedding = freeze_embedding
        self.teacher_forcing_ratio = hparams['teacher_forcing_ratio']
        self.model_1_enc = Encoder(vocab_size, embed_dim, hidden_size, n_layers, dropout=dropout, embedding=embedding, bidirectional=False)
        self.model_2_enc = Encoder(vocab_size, embed_dim, hidden_size, n_layers, dropout=dropout, embedding=embedding, bidirectional=False)

        self.model_3_mem_a = MemRNN(hidden_size, dropout=dropout)
        # NOTE(review): model_3_mem_b is registered (so it appears in state_dict and
        # the optimizer) but is never called; model_3_mem_a is used both for the
        # episode steps and for the memory update. Confirm whether the memory update
        # was meant to use model_3_mem_b.
        self.model_3_mem_b = MemRNN(hidden_size, dropout=dropout)
        self.model_4_att = EpisodicAttn(hidden_size, dropout=dropout)
        self.model_5_ans = AnswerModule(vocab_size, hidden_size, dropout=dropout)

        self.input_var = None  # for input
        self.q_var = None  # for question
        self.answer_var = None  # for answer
        self.q_q = None  # extra question
        self.inp_c = None  # extra input
        self.inp_c_seq = None
        self.all_mem = None
        self.last_mem = None  # output of mem unit
        self.prediction = None  # final single word prediction
        self.memory_hops = hparams['babi_memory_hops']

        # NOTE(review): supplying `embedding` freezes the embedding weights even when
        # freeze_embedding=False; confirm that is intended.
        if self.freeze_embedding or self.embedding is not None:
            self.new_freeze_embedding()

    def forward(self, input_variable, question_variable, target_variable, criterion=None):
        """Encode the inputs, run the memory hops and score the answer.

        Returns:
            ``(outputs, None, loss, None)``. ``outputs`` is a one-element list that
            holds the argmax index of the first row of the answer scores. ``loss``
            is ``criterion(scores, target_variable[0])``, or ``0`` when no
            criterion is given.
        """
        self.new_input_module(input_variable, question_variable)
        self.new_episodic_module()
        outputs, loss = self.new_answer_module_simple(target_variable, criterion)

        return outputs, None, loss, None

    def new_freeze_embedding(self):
        """Stop gradient updates to both encoders' ``embed`` weights."""
        self.model_1_enc.embed.weight.requires_grad = False
        self.model_2_enc.embed.weight.requires_grad = False
        print('freeze embedding')

    def new_input_module(self, input_variable, question_variable):
        """Encode story and question and cache the results on ``self``.

        Sets ``inp_c_seq`` (story encoder outputs), ``inp_c`` (story encoder hidden
        state) and ``q_q`` (question encoder hidden state).
        """
        out1, hidden1 = self.model_1_enc(input_variable)

        self.inp_c_seq = out1
        self.inp_c = hidden1

        _, hidden2 = self.model_2_enc(question_variable)

        self.q_q = hidden2

    def new_episodic_module(self):
        """Run ``self.memory_hops`` attention / episode / memory-update passes.

        Reads ``self.inp_c_seq`` and ``self.q_q`` (set by ``new_input_module``),
        stores the final memory in ``self.last_mem`` and returns it.
        """
        # Initial gate / episode state. These are plain zeros on the question
        # encoding's device and dtype. The previous `nn.Parameter(...)` objects were
        # created on every call and never registered with the module, so the
        # optimizer never updated them. They also always lived on the CPU, which
        # breaks the model on a GPU.
        zeros = torch.zeros(1, 1, self.hidden_size,
                            dtype=self.q_q.dtype, device=self.q_q.device)
        m_list = [self.q_q.clone()]  # memory after each hop; starts at the question
        g_list = [zeros]             # attention gates, in the order they are computed
        e_list = [zeros]             # episode hidden states

        # Loop-invariant, so compute it once: one row per input sentence.
        # Assumes inp_c_seq is (seq, batch=1, hidden) -- TODO confirm.
        sequences = self.inp_c_seq.permute(1, 0, 2).squeeze(0)

        for _hop in range(self.memory_hops):
            hop_gates = []
            for ct in sequences:
                g = self.new_attention_step(ct, g_list[-1], m_list[-1], self.q_q)
                g_list.append(g)
                hop_gates.append(g)

            # Gate each sentence with ITS OWN attention score. The previous code
            # used g_list[-1] (the last sentence's score) for every sentence, so the
            # attention could not select which facts update the episode.
            # NOTE(review): e_list[-1] carries the episode state over from the
            # previous hop rather than restarting from zeros; confirm intended.
            for ct, g in zip(sequences, hop_gates):
                e, _ = self.new_episode_small_step(ct, g, e_list[-1])
                e_list.append(e)

            _, out = self.model_3_mem_a(e_list[-1], m_list[-1])
            m_list.append(out)

        self.last_mem = m_list[-1]
        return self.last_mem

    def new_episode_small_step(self, ct, g, prev_h):
        """One gated episode step: ``h = g * GRU(ct, prev_h) + (1 - g) * prev_h``.

        Returns:
            ``(h, gru)``: the gated hidden state and the raw ``model_3_mem_a`` output.
        """
        _, gru = self.model_3_mem_a(ct, prev_h, None)
        h = g * gru + (1 - g) * prev_h

        return h, gru

    def new_attention_step(self, ct, prev_g, mem, q_q):
        """Build the interaction features for one sentence and score them.

        The features are the sentence, the memory, the question, their element-wise
        products, and their absolute differences. ``prev_g`` is accepted for
        interface compatibility but is currently not part of the features.
        """
        concat_list = [
            ct.unsqueeze(0),
            mem.squeeze(0),
            q_q.squeeze(0),
            (ct * q_q).squeeze(0),
            (ct * mem).squeeze(0),
            torch.abs(ct - q_q).squeeze(0),
            torch.abs(ct - mem).squeeze(0),
        ]
        return self.model_4_att(concat_list)

    def new_answer_module_simple(self, target_var, criterion):
        """Score the answer from the final memory and the question encoding.

        Returns:
            ``([ans], loss)``. ``ans`` is the argmax over dim 1 of the first row of
            the scores. ``loss`` is ``criterion(scores, target_var[0])``, or ``0``
            when ``criterion`` is None.
        """
        loss = 0
        ansx = self.model_5_ans(self.last_mem, self.q_q)
        ans = torch.argmax(ansx, dim=1)[0]
        if criterion is not None:
            loss = criterion(ansx, target_var[0])

        return [ans], loss

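I don’t know why your model doesn’t train well, but calling other methods from your `forward()` method probably isn’t the problem. Autograd records every operation executed during the forward pass, no matter which method performs it. The sub-modules only need to be assigned as attributes in `__init__` so that their parameters are registered.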

Thanks. I wasn’t sure and felt I had to ask. Thanks again.

My model still has problems, so I have opened a Stack Overflow question. If you are interested in the issue in general, check out this link: https://stackoverflow.com/questions/51154949/dmn-neural-network-with-poor-validation-results-only-50 . Thanks.