Detaching only specific indices of hidden states in LSTM

Hi, I am using an LSTM with the A2C algorithm. That means I have multiple instances of agents and environments, each of which gets reset (agent failed/succeeded) at a different time.
The problem is that I want to “repack” (as in the pytorch LM example) only specific indices of the hidden states. However, trying to repack only part of the hidden state did not go well.
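
(For reference, the repackaging helper from that example looks roughly like this; it detaches the whole (h, c) state at once, which is exactly what I can't do here:)

def repackage_hidden(h):
    """Wraps hidden states in new Variables, to detach them from their history."""
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)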

Here is some simple code to illustrate the problem:

import time

import torch
import torch.nn as nn
from torch.autograd import Variable

async_reset = True

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.lstm = nn.LSTM(10, 5, num_layers=1)

    def forward(self, input, hidden):
        return self.lstm(input, hidden)

net = SimpleModel()
optimizer = torch.optim.RMSprop(net.parameters())

input = Variable(torch.ones((1, 4, 10)))  # (seq_len=1, batch=4, input_size=10)
hidden = None

last_time = time.time()

odd = True
for i in range(20000):
    output, hidden = net(input, hidden)
    res = torch.sum(output)

    # hidden = Variable(hidden[0].data), Variable(hidden[1].data)
    h, c = hidden  # nn.LSTM returns the state as (h_n, c_n)

    if async_reset:
        # try to detach only half of the batch rows, alternating each step (in-place)
        if odd:
            c[:, :2] = Variable(c[:, :2].data)
            h[:, :2] = Variable(h[:, :2].data)
        else:
            c[:, 2:] = Variable(c[:, 2:].data)
            h[:, 2:] = Variable(h[:, 2:].data)
    else:
        # detach the whole hidden state every step
        c = Variable(c.data)
        h = Variable(h.data)
    print(odd)
    odd = not odd
    hidden = (h, c)
    print("==")

    optimizer.zero_grad()
    loss = torch.sum(output)
    loss.backward(retain_graph=True)

    print("%d: Took %.2f second" % (i, time.time() - last_time))
    last_time = time.time()

Note that when async_reset is False, each iteration runs very fast (<0.01 seconds), since there are no previous timesteps to backpropagate through.
When async_reset is True, a different half of the batch is repacked each time. This means the model should never need to look back more than a step or two, and it should still be very fast.
But the results are different: the iterations keep getting slower (0.3 seconds per iteration after 230 runs).

From this I understand that the variable still carries the older timesteps, probably because the variable is not detached from the graph as a whole. Other approaches, such as starting from a zero hidden state and assigning into it the part of the batch that should not be reset, didn't help either (roughly what I tried is sketched below).
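
For what it's worth, that zero-state attempt looked roughly like this (reconstructed from memory, so the exact slices are a guess), applied on odd steps, using the h and c from the loop above, before feeding the state back in:

h_new = Variable(torch.zeros(h.size()))
h_new[:, 2:] = h[:, 2:]   # keep the graph only for the rows that should not be reset
c_new = Variable(torch.zeros(c.size()))
c_new[:, 2:] = c[:, 2:]
hidden = (h_new, c_new)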

Any way to prevent this?

Thanks so much!

An indexing operation like c[:, :2] = Variable(c[:, :2].data) doesn’t help you; it still preserves history.

What you need to do is create an all-zeros state and then set the particular part:

c_new = torch.zeros_like(c)
c_new.data[:, :2] = c.data[:, :2]  # copying through .data bypasses autograd, so no history is attached
c = c_new

Thanks for your response!
The thing is I actually do want to preserve history, but only for specific rows.
If I’m not mistaken, with your example the entire history will be deleted.

In the example above I delete the history of the first 2 rows on every odd step, and the history of the last 2 rows on every even step. That way, in theory, the maximum length of the history should be 2 steps. However, the entire history is preserved.
If I delete the history for all rows at the same time (async_reset = False), no history is preserved.

Is there a way to implement such a thing, or must I split it into two separate LSTM calls?
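
Something like the following is what I have in mind, though I haven't verified that it really keeps the graph only for the non-detached rows: rebuilding the state out of place with torch.cat, detaching only the slice whose episodes just ended, instead of assigning in place (this reuses the h, c and odd from the loop above):

if odd:
    h = torch.cat([h[:, :2].detach(), h[:, 2:]], 1)
    c = torch.cat([c[:, :2].detach(), c[:, 2:]], 1)
else:
    h = torch.cat([h[:, :2], h[:, 2:].detach()], 1)
    c = torch.cat([c[:, :2], c[:, 2:].detach()], 1)
hidden = (h, c)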

My main motivation is implementing the A2C algorithm, where I use BPTT through all timesteps since the beginning of the game; the problem is that games end at different times for different agents.
The only other solution I can currently think of is to implement A3C instead, unless I am missing something.

Thanks again