Hi. I’m new to PyTorch. I’m working on a project that uses the bAbI dataset. My code is messy, so I want to show as little of it as possible. I have several modules, and one of them is a wrapper for the others. Besides `forward`, the wrapper module has several other methods, and its `forward` calls them in turn. Do I have to worry about this setup? Will my code train properly? I’m actually trying to fix a problem where my model stops improving once it reaches about 50% accuracy. Could that somehow be related to how my wrapper’s `forward` works?
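To make it concrete, here is a stripped-down toy version of the pattern I mean (this is not my real code, just an illustration): forward() only calls the module's own helper methods.

import torch
import torch.nn as nn

class TinyWrapper(nn.Module):
    """Toy example: forward() delegates to helper methods."""
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def _encode(self, x):
        # helper method called from forward
        return torch.relu(self.fc1(x))

    def _decode(self, h):
        # another helper method called from forward
        return self.fc2(h)

    def forward(self, x):
        h = self._encode(x)
        return self._decode(h)

My real wrapper is shaped like that, just with more pieces: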
import torch
import torch.nn as nn

# `hparams` is a config dict defined elsewhere in my project; Encoder,
# MemRNN, EpisodicAttn, and AnswerModule are my other modules.

class WrapMemRNN(nn.Module):

    def __init__(self, vocab_size, embed_dim, hidden_size, n_layers,
                 dropout=0.3, do_babi=True, bad_token_lst=None,
                 freeze_embedding=False, embedding=None):
        super(WrapMemRNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.do_babi = do_babi
        # avoid a mutable default argument; fall back to an empty list here
        self.bad_token_lst = bad_token_lst if bad_token_lst is not None else []
        self.embedding = embedding
        self.freeze_embedding = freeze_embedding
        self.teacher_forcing_ratio = hparams['teacher_forcing_ratio']

        # two encoders: one for the story input, one for the question
        self.model_1_enc = Encoder(vocab_size, embed_dim, hidden_size, n_layers,
                                   dropout=dropout, embedding=embedding, bidirectional=False)
        self.model_2_enc = Encoder(vocab_size, embed_dim, hidden_size, n_layers,
                                   dropout=dropout, embedding=embedding, bidirectional=False)
        self.model_3_mem_a = MemRNN(hidden_size, dropout=dropout)
        self.model_3_mem_b = MemRNN(hidden_size, dropout=dropout)  # note: never called in forward
        self.model_4_att = EpisodicAttn(hidden_size, dropout=dropout)
        self.model_5_ans = AnswerModule(vocab_size, hidden_size, dropout=dropout)
        # intermediate state produced during forward()
        self.input_var = None   # story input
        self.q_var = None       # question
        self.answer_var = None  # answer
        self.q_q = None         # encoded question (final hidden state)
        self.inp_c = None       # encoded input (final hidden state)
        self.inp_c_seq = None   # encoded input (full output sequence)
        self.all_mem = None
        self.last_mem = None    # output of the memory unit
        self.prediction = None  # final single-word prediction
        self.memory_hops = hparams['babi_memory_hops']

        if self.freeze_embedding or self.embedding is not None:
            self.new_freeze_embedding()
    def forward(self, input_variable, question_variable, target_variable, criterion=None):
        # forward just chains the helper methods below
        self.new_input_module(input_variable, question_variable)
        self.new_episodic_module()
        outputs, loss = self.new_answer_module_simple(target_variable, criterion)
        return outputs, None, loss, None
    def new_freeze_embedding(self):
        self.model_1_enc.embed.weight.requires_grad = False
        self.model_2_enc.embed.weight.requires_grad = False
        print('freeze embedding')
    def new_input_module(self, input_variable, question_variable):
        out1, hidden1 = self.model_1_enc(input_variable)
        self.inp_c_seq = out1
        self.inp_c = hidden1
        out2, hidden2 = self.model_2_enc(question_variable)
        self.q_q = hidden2
    def new_episodic_module(self):
        m_list = [self.q_q.clone()]  # episodic memory, seeded with the question
        # plain zero tensors, NOT nn.Parameter: parameters created inside
        # forward are re-made on every call and never registered with the
        # optimizer, so they would never be trained
        g_list = [torch.zeros(1, 1, self.hidden_size)]
        e_list = [torch.zeros(1, 1, self.hidden_size)]
        f_list = [torch.zeros(1, 1, self.hidden_size)]
        for hop in range(self.memory_hops):
            sequences = self.inp_c_seq.clone().permute(1, 0, 2).squeeze(0)
            # attention gate for each fact in the story
            for i in range(len(sequences)):
                x = self.new_attention_step(sequences[i], g_list[-1], m_list[-1], self.q_q)
                g_list.append(x)
            # gated episode update over the same facts
            for i in range(len(sequences)):
                e, f = self.new_episode_small_step(sequences[i], g_list[-1], e_list[-1])
                e_list.append(e)
                f_list.append(f)
            # update the memory with the final episode state
            _, out = self.model_3_mem_a(e_list[-1], m_list[-1])
            m_list.append(out)
        self.last_mem = m_list[-1]
        return m_list[-1]
    def new_episode_small_step(self, ct, g, prev_h):
        _, gru = self.model_3_mem_a(ct, prev_h, None)
        h = g * gru + (1 - g) * prev_h  # gated update; g comes from the attention step
        return h, gru
    def new_attention_step(self, ct, prev_g, mem, q_q):
        # similarity features between the current fact, the memory,
        # and the question, fed to the attention module
        concat_list = [
            ct.unsqueeze(0),
            mem.squeeze(0),
            q_q.squeeze(0),
            (ct * q_q).squeeze(0),
            (ct * mem).squeeze(0),
            torch.abs(ct - q_q).squeeze(0),
            torch.abs(ct - mem).squeeze(0),
        ]
        return self.model_4_att(concat_list)
    def new_answer_module_simple(self, target_var, criterion):
        loss = 0
        ansx = self.model_5_ans(self.last_mem, self.q_q)
        ans = torch.argmax(ansx, dim=1)[0]  # predicted vocab index
        if criterion is not None:
            loss = criterion(ansx, target_var[0])
        return [ans], loss
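For context, this is roughly how I call the wrapper in my training loop. The shapes and names here (story_len, q_len, answer_idx, and so on) are placeholders, not my exact values:

import torch
import torch.nn as nn

model = WrapMemRNN(vocab_size, embed_dim, hidden_size, n_layers, embedding=embedding)
criterion = nn.CrossEntropyLoss()
# only optimize parameters that still require grad (embeddings may be frozen)
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

story = torch.randint(0, vocab_size, (story_len, 1))   # encoded story tokens
question = torch.randint(0, vocab_size, (q_len, 1))    # encoded question tokens
target = torch.tensor([[answer_idx]])                  # single-word answer index

optimizer.zero_grad()
outputs, _, loss, _ = model(story, question, target, criterion)
loss.backward()
optimizer.step()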