Hi, I am using an LSTM with the A2C algorithm. That means I have multiple agent/environment instances, each of which gets reset (when its agent fails or succeeds) at a different time.
The problem is that I want to "repack" the hidden state (as in the PyTorch LM example), but only at specific batch indices. However, trying to repack only part of the hidden state did not go well.
Here is some simple code to illustrate the problem:
import time

import torch
import torch.nn as nn
from torch.autograd import Variable

async_reset = True

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.lstm = nn.LSTM(10, 5, num_layers=1)

    def forward(self, input, hidden):
        return self.lstm(input, hidden)

net = SimpleModel()
optimizer = torch.optim.RMSprop(net.parameters())
input = Variable(torch.ones((1, 4, 10)))  # (seq_len=1, batch=4, input_size=10)
hidden = None
last_time = time.time()
odd = True
for i in range(20000):
    output, hidden = net(input, hidden)
    # hidden = Variable(hidden[0].data), Variable(hidden[1].data)
    h, c = hidden  # nn.LSTM returns the hidden tuple as (h_n, c_n)
    if async_reset:
        if odd:
            # repack only the first half of the batch
            h[:, :2] = Variable(h[:, :2].data)
            c[:, :2] = Variable(c[:, :2].data)
        else:
            # repack only the second half of the batch
            h[:, 2:] = Variable(h[:, 2:].data)
            c[:, 2:] = Variable(c[:, 2:].data)
    else:
        # repack the whole hidden state
        h = Variable(h.data)
        c = Variable(c.data)
    odd = not odd
    hidden = (h, c)
    optimizer.zero_grad()
    loss = torch.sum(output)
    loss.backward(retain_graph=True)
    print("%d: Took %.2f seconds" % (i, time.time() - last_time))
    last_time = time.time()
Note that when async_reset is False, the model runs very fast (<0.01 seconds per iteration), since there are no previous timesteps to propagate through.
When async_reset is True, a different half of the batch is repacked on each iteration. That means the model should only ever have to look one timestep back, so it should still be very fast.
But the results are different: the iterations keep getting slower (0.3 seconds each after 230 iterations).
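To double-check, here is a minimal single-step repro of the partial repack (a sketch; the variable names are mine, mirroring the loop above). After the slice assignment, grad_fn on the states is still set, which suggests the graph behind them is not actually cut:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

lstm = nn.LSTM(10, 5, num_layers=1)
inp = Variable(torch.ones(1, 4, 10))  # (seq_len=1, batch=4, input_size=10)

out, (h, c) = lstm(inp)

# "repack" only the first two batch entries, as in the loop above
h[:, :2] = Variable(h[:, :2].data)
c[:, :2] = Variable(c[:, :2].data)

# the states still carry history: grad_fn is not None, so a backward
# pass through them would still revisit earlier timesteps
print(h.grad_fn is not None)  # True
print(c.grad_fn is not None)  # True
```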
I take this to mean that the hidden state still references older timesteps, probably because the variable is not fully detached from the graph. Other methods, such as starting from a zero hidden state and copying in the part of the batch that should not be reset, didn't help either.
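For concreteness, the zero-hidden-state variant I mention looks roughly like this (a sketch; the shapes and tensor names are mine):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

lstm = nn.LSTM(10, 5, num_layers=1)
out, (h, c) = lstm(Variable(torch.ones(1, 4, 10)))

# start from fresh zero states, then copy back the half of the
# batch (indices 2:) that should keep its history
new_h = Variable(torch.zeros(h.size()))
new_c = Variable(torch.zeros(c.size()))
new_h[:, 2:] = h[:, 2:]
new_c[:, 2:] = c[:, 2:]
hidden = (new_h, new_c)

# the copied slice drags the old graph along with it, so this has
# the same problem as the in-place repack
print(new_h.grad_fn is not None)  # True
```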
Is there any way to prevent this?
Thanks so much!