LSTMCell: reset the hidden state of a specific batch element to zeros during the forward pass

To simplify the question, let's use the sample code from the documentation:

>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)

Let’s say that at time step 4 (out of 6) in the sequence, we want to reset the hidden state of the second example in the batch (batch size 3), i.e. that example's sequence terminates and a new episode comes in at the next time step.

In this case, what is the right way to do this? Should I use hx[1].fill_(0.0)? To my understanding, this might break the backward pass, since modifying tensor values in place like that is not allowed.
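For concreteness, reusing the setup above, the in-place version would sit inside the loop like this (just a sketch; the step index and batch index are only my example):

output = []
for i in range(6):
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)
    if i == 3:             # the second example's episode ends at time step 4
        hx[1].fill_(0.0)   # in-place reset of batch index 1; is this safe for autograd?
        cx[1].fill_(0.0)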

Another idea I have in mind is to multiply by a mask tensor created on the fly, which does not require gradients,
e.g. hx = hx * torch.tensor([1.0, 0.0, 1.0]).unsqueeze(1).expand_as(hx)
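
Put together, the mask-based version would look something like this (an untested sketch; the reset step and the mask values just follow the example above, and broadcasting should make the expand_as optional):

import torch
import torch.nn as nn

rnn = nn.LSTMCell(10, 20)
input = torch.randn(6, 3, 10)   # (time_steps, batch, input_size)
hx = torch.randn(3, 20)         # (batch, hidden_size)
cx = torch.randn(3, 20)

output = []
for i in range(6):
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)
    if i == 3:
        # The second example's episode ends at time step 4, so zero its hidden
        # and cell state before the next step. The mask is a constant created
        # on the fly; shape (3, 1) broadcasts against (3, 20).
        mask = torch.tensor([1.0, 0.0, 1.0]).unsqueeze(1)
        hx = hx * mask
        cx = cx * mask

My understanding is that multiplying by this constant mask simply stops gradients from flowing across the episode boundary for the reset example, while the other two examples are untouched. Is that correct, and is this the recommended way to do it?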