I need to update a variable c_next based on an indexed (masked) selection, like this:

(z and zb contain only 1s and 0s)

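```python
# z and zb are 0/1 tensors; indexing masks should be bool (uint8 masks from .byte() are deprecated)
z_b = z.detach().bool()
zb_b = zb.detach().bool()

# UPDATE positions: z == 0 and zb == 1, broadcast to the shape of f
update = (~z_b & zb_b).expand_as(f)

# the right-hand side must be masked too, so both sides select the same elements
c_next[update] = (f * c + i * g)[update]
```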
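A minimal sketch of why the indexed version does not help when z itself needs gradients, using hypothetical toy tensors z and x: the boolean mask can only be built from a detached copy of z, so after backward() the written values get a gradient but z does not.

```python
import torch

# Hypothetical toy tensors: z is a binary gate that requires grad, x is the value being written.
z = torch.tensor([0.0, 1.0, 0.0], requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# A boolean mask can only be built from a detached (non-differentiable) copy of z.
mask = z.detach() == 0

out = torch.zeros(3)
out[mask] = x[mask]   # masked assignment: gradient flows to x only
out.sum().backward()

print(x.grad)  # tensor([1., 0., 1.])
print(z.grad)  # None: z never entered the autograd graph
```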

But z and zb need gradients, and a mask built from them for indexing would not pass any gradient back to them, so I do it this way instead:

```python
c_next = (f * c + i * g) * torch.abs(z - 1) * zb + c_next
```

This seems more computationally expensive. Is there a simpler way to do this while keeping the gradients flowing?
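One possible simplification, as a sketch with hypothetical shapes (z and zb of shape (batch, 1), the other tensors of shape (batch, hidden)): because z only takes the values 0 and 1, torch.abs(z - 1) equals 1 - z, so the UPDATE mask is just (1 - z) * zb. The whole update is a few element-wise multiplications, which is usually cheap next to the matrix multiplications in the LSTM cell. One behavioural difference worth knowing: PyTorch's abs has gradient 0 at 0, so the torch.abs form sends no gradient to z wherever z == 1, while 1 - z sends -1 everywhere.

```python
import torch

torch.manual_seed(0)
batch, hidden = 4, 5

# Hypothetical shapes: z and zb are (batch, 1) binary gates that require grad
# (e.g. produced by a straight-through estimator); f, i, g, c are (batch, hidden).
z = torch.randint(0, 2, (batch, 1)).float().requires_grad_()
zb = torch.randint(0, 2, (batch, 1)).float().requires_grad_()
f, i, g, c = (torch.randn(batch, hidden, requires_grad=True) for _ in range(4))
c_next = torch.zeros(batch, hidden)

# For z in {0, 1}, torch.abs(z - 1) == 1 - z, so the UPDATE mask is just (1 - z) * zb.
assert torch.equal(torch.abs(z - 1), 1 - z)
update = (1 - z) * zb                   # broadcasts over the hidden dimension
c_next = update * (f * c + i * g) + c_next

c_next.sum().backward()
print(z.grad.shape, zb.grad.shape)      # gradients reach both gates

# Separate check: abs() has derivative 0 at 0, i.e. where z == 1,
# so the torch.abs form gives gradient -1 where z == 0 and 0 where z == 1.
z_demo = torch.tensor([0.0, 1.0], requires_grad=True)
torch.abs(z_demo - 1).sum().backward()
print(z_demo.grad)
```

torch.where(mask, f * c + i * g, c_next) would avoid the multiplications, but like indexing it takes a boolean condition, so it would not pass any gradient to z or zb either.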

(This is part of equation 2 from “Hierarchical Multiscale LSTM”, Chung et al. 2016, https://arxiv.org/abs/1609.01704.)
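For reference, a sketch of all three cases of equation 2 written the same way, assuming z is the layer's own boundary state from the previous step (z_{t-1}^l), zb is the boundary from the layer below (z_t^{l-1}), and the cases are UPDATE (z = 0, zb = 1), COPY (z = 0, zb = 0) and FLUSH (z = 1), as I read the paper. The helper name hm_lstm_cell_state is hypothetical. For binary gates the three masks are mutually exclusive and sum to 1, so exactly one term survives in each row and no indexing is needed.

```python
import torch


def hm_lstm_cell_state(z, zb, f, i, g, c_prev):
    """Hypothetical helper: differentiable form of the three cases of Eq. 2.

    z      -- boundary of this layer at the previous step (assumed z_{t-1}^l), values in {0, 1}
    zb     -- boundary from the layer below (assumed z_t^{l-1}), values in {0, 1}
    f, i, g, c_prev -- forget gate, input gate, candidate, previous cell state
    """
    update_m = (1 - z) * zb          # UPDATE: z == 0 and zb == 1
    copy_m = (1 - z) * (1 - zb)      # COPY:   z == 0 and zb == 0
    flush_m = z                      # FLUSH:  z == 1
    return update_m * (f * c_prev + i * g) + copy_m * c_prev + flush_m * (i * g)


torch.manual_seed(0)
batch, hidden = 3, 4
# One row per case: UPDATE, COPY, FLUSH.
z = torch.tensor([[0.0], [0.0], [1.0]], requires_grad=True)
zb = torch.tensor([[1.0], [0.0], [1.0]], requires_grad=True)
f, i, g, c_prev = (torch.randn(batch, hidden) for _ in range(4))

c = hm_lstm_cell_state(z, zb, f, i, g, c_prev)
assert torch.allclose(c[0], f[0] * c_prev[0] + i[0] * g[0])  # UPDATE row
assert torch.allclose(c[1], c_prev[1])                         # COPY row
assert torch.allclose(c[2], i[2] * g[2])                       # FLUSH row

c.sum().backward()  # gradients flow back into both boundary gates
```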