Looking at this post, I wanted to verify whether the `train` method is meant to be called once per sequence.
Code from the post:

```python
def train(self, input_sequence, init_state):
    states = [(None, init_state)]
    outputs = []
    targets = []

    for i, (inp, target) in enumerate(input_sequence):
        state = states[-1][1].detach()
        state.requires_grad = True
        output, new_state = self.one_step_module(inp, state)
        outputs.append(output)
        targets.append(target)
        while len(outputs) > self.k1:
            # Delete stuff that is too old
            del outputs[0]
            del targets[0]

        states.append((state, new_state))
        while len(states) > self.k2:
            # Delete stuff that is too old
            del states[0]

        if (i + 1) % self.k1 == 0:
            # loss = self.loss_module(output, target)
            optimizer.zero_grad()
            # backprop last module (keep graph only if they ever overlap)
            start = time.time()
            # loss.backward(retain_graph=self.retain_graph)
            for j in range(self.k2 - 1):
                if j < self.k1:
                    loss = self.loss_module(outputs[-j - 1], targets[-j - 1])
                    loss.backward(retain_graph=True)
                # if we get all the way back to the "init_state", stop
                if states[-j - 2][0] is None:
                    break
                curr_grad = states[-j - 1][0].grad
                states[-j - 2][1].backward(curr_grad, retain_graph=self.retain_graph)
            print("bw: {}".format(time.time() - start))
            optimizer.step()
```
If I understand it correctly, the `for` loop iterates over each item in the sequence, meaning that if `model = self.one_step_module = RNN(...)`, then `self.one_step_module(inp, state)` is only ever run on a sequence of length 1.
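For what it's worth, that reading would match a one-step module built from something like `nn.RNNCell` (a hypothetical stand-in here, since the post doesn't say which module it is), which consumes exactly one timestep per call, unlike `nn.RNN`, which consumes a whole sequence:

```python
import torch
from torch import nn

# Hypothetical one-step module: an RNNCell takes one timestep per call,
# so each call in the loop really does see a "sequence" of length 1.
cell = nn.RNNCell(input_size=3, hidden_size=5)

state = torch.zeros(1, 5)      # (batch, hidden_size)
inp = torch.randn(1, 3)        # one timestep: (batch, input_size)

new_state = cell(inp, state)   # analogous to self.one_step_module(inp, state)
print(new_state.shape)         # torch.Size([1, 5])
```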
Then the rest of the code runs backprop through the last `k2` items at an interval of `k1` items.
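A minimal sketch of the mechanism that makes this work (a toy scalar example, not the post's code): detaching the state cuts the graph between steps, but because the detached copy has `requires_grad=True`, its `.grad` can be fed back manually with `backward(gradient)`, which is what the inner loop does with `states[-j-2][1].backward(curr_grad, ...)`:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
s0 = torch.tensor(1.0)

s1 = w * s0              # step 1
s1_in = s1.detach()      # cut the graph, like states[-1][1].detach()
s1_in.requires_grad = True
s2 = w * s1_in           # step 2 uses the detached state

s2.backward()            # backprop through step 2 only
# s1_in.grad now holds d(s2)/d(s1_in) = w = 2.0
s1.backward(s1_in.grad)  # manually chain that gradient into step 1

# w accumulates d(s2)/dw = s1_in = 2.0 plus the chained 2.0 * d(s1)/dw = 2.0
print(w.grad)            # tensor(4.)
```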
I think that would make sense, since if `inp` were an entire sequence, we'd be using state from an unrelated example to inform the next example.

Is this all correct?