I think my code is a total mess… On every training loop I want to save the `self.num_steps` hidden states for further use. `rnn_first` means the GRU's inputs are the full `self.num_steps` block. Leaving `rnn_last` aside (because the bug arises before it), on every call other than `rnn_first` I only feed the last of the inputs (my data is split into `self.num_steps` blocks), and I reset the GRU's `self.state` to the last entry of `self.rnn_states`.
I'm not sure if you can understand me or not…
def rnn(self, rnn_first, rnn_last, inputs_r, target):
    """Run one segment of the GRU over embedded inputs.

    On the first call for a sequence (``rnn_first``) the whole block of
    inputs is fed step by step and every hidden state is stored in
    ``self.rnn_states``. On later calls only one new token is fed — the
    last embedded input, or ``target`` when ``rnn_last`` — and
    ``self.rnn_states`` is shifted left by one so it always holds the
    most recent hidden states.

    Args:
        rnn_first: True on the first call of a sequence.
        rnn_last: True on the final call (feeds ``target`` instead of inputs).
        inputs_r: index tensor embedded to (num_steps, batch_size, global_dim).
        target: index tensor for the final step; only used when ``rnn_last``.

    Side effects: updates ``self.state`` and ``self.rnn_states``
    (and ``self.initial_state``, kept for backward compatibility).

    NOTE(review): the ``.view(1, self.global_dim)`` calls on hidden
    states only work when ``self.batch_size == 1`` — confirm intended.
    """
    rnn_inputs = self.embed(inputs_r)
    # BUG FIX: F.alpha_dropout's `p` is the DROP probability, while
    # `keep_prob` (a TF-style name) is the probability of KEEPING a
    # unit — the rate must be inverted. (alpha_dropout is designed for
    # SELU activations; plain F.dropout may be what you want here.)
    rnn_inputs = F.alpha_dropout(rnn_inputs, p=1.0 - self.keep_prob)
    # Kept for backward compatibility with any outside reader of this
    # attribute; autograd.Variable is deprecated — plain tensors carry
    # autograd in modern PyTorch.
    self.initial_state = torch.zeros(1, self.batch_size, self.global_dim)

    if rnn_first:
        self.state = self.initial_state
        # Collect each step's hidden state in a list and concatenate
        # once at the end — repeated torch.cat inside the loop is O(n^2).
        states = []
        for step_input in rnn_inputs:
            rnn_out, self.state = self.gru(
                step_input.view(1, self.batch_size, self.global_dim),
                self.state)
            states.append(self.state.view(1, self.global_dim))
        self.rnn_states = torch.cat(states, 0)
    else:
        # Resume from the newest stored hidden state. BUG FIX: detach()
        # cuts the autograd graph built on a previous iteration —
        # without it backward() raises "Trying to backward through the
        # graph a second time", which is the likely reported bug.
        self.state = self.rnn_states[-1].detach().view(
            1, self.batch_size, self.global_dim)
        if rnn_last:
            # Final call: feed the target token (no loss output needed).
            step_input = self.embed(target)
        else:
            # Intermediate call: feed only the newest input of the block.
            step_input = rnn_inputs[-1]
        output, self.state = self.gru(
            step_input.view(1, self.batch_size, self.global_dim),
            self.state)
        # Slide the window: drop the oldest state, append the newest.
        # Plain slicing replaces the cat(split(...)[1:]) round-trip.
        self.rnn_states = torch.cat(
            (self.rnn_states[1:], self.state.view(1, self.global_dim)), 0)