I need to update a variable c_next using a masked (boolean-index) assignment, like this:
(z and zb contain only 1s and 0s.)
z_b = z.detach().bool()
zb_b = zb.detach().bool()
update = (~z_b & zb_b).expand_as(f)
c_next[update] = (f * c + i * g)[update]
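Here is a self-contained sketch of the masked version. It assumes f, c, i and g have shape (batch, hidden) and that z and zb are float 0/1 tensors of shape (batch, 1); the shapes and values are made up for illustration. The right-hand side is indexed with the same mask so the element counts match, and a bool mask replaces the deprecated uint8 (.byte()) mask. Gradients still reach f, c, i and g through the selected elements, but nothing flows back to z or zb, because the mask acts as a constant index.

```python
import torch

def masked_update(c_next, f, c, i, g, z, zb):
    # Masks come from detached copies, so they carry no gradient.
    # (z == 0) & (zb == 1) is the UPDATE case; expand over the hidden dim.
    update = ((z.detach() == 0) & (zb.detach() == 1)).expand_as(f)
    # Clone so the in-place write does not modify the caller's tensor.
    out = c_next.clone()
    # Index the RHS with the same mask so the number of elements matches.
    out[update] = (f * c + i * g)[update]
    return out
```

Because the mask is constant, z.grad stays None after backward, which is why the multiplicative form is needed when z must receive gradients.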
But z and zb require gradients, and a mask built from them would not pass any gradient back to them, so I do it this way instead:
c_next = (f * c + i * g) * (1 - z) * zb + c_next
This seems more computationally expensive. Is there a simpler way to do this while keeping the gradients flowing?
(This is part of equation 2 of the HM-LSTM in "Hierarchical Multiscale Recurrent Neural Networks", Chung et al. 2016, https://arxiv.org/abs/1609.01704.)
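Because z and zb are exactly 0 or 1, all three cases of equation 2 (UPDATE, COPY and FLUSH in the paper) can be written as one elementwise blend that stays differentiable with respect to z and zb. This is a sketch, assuming z plays the role of z_{t-1}^l (this layer's previous boundary) and zb plays the role of z_t^{l-1} (the boundary from the layer below), both as float tensors broadcastable to f; the function name is hypothetical. It uses 1 - z rather than torch.abs(z - 1): the values agree for binary z, but PyTorch defines the gradient of abs at 0 as 0, so abs(z - 1) silently blocks the gradient wherever z == 1.

```python
import torch

def hm_lstm_cell_state(c, f, i, g, z, zb):
    """Hypothetical helper: eq. 2 (UPDATE / COPY / FLUSH) as one blend.

    Assumes z = z_{t-1}^l and zb = z_t^{l-1}, both 0/1 float tensors.
    The three weights sum to 1, so exactly one case is active per row.
    """
    ig = i * g
    update = (1 - z) * zb          # z == 0 and zb == 1
    copy = (1 - z) * (1 - zb)      # z == 0 and zb == 0
    flush = z                      # z == 1
    return update * (f * c + ig) + copy * c + flush * ig
```

The extra cost is a handful of elementwise multiplies and adds, which is usually small next to the matrix multiplies that produce f, i and g. torch.where(cond, a, b) would also pick between branches, but like indexing it passes no gradient to cond.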