What I don't understand is the criterion you use for branching. Is it that
a sample goes to one of the branches if the first element of its data is 0?
In any case, one way to split a batch is to evaluate the condition per sample
to get a boolean mask (1D, same length as the batch size). Then use that mask
and its inverse to index into your tensor, which gives you two sub-batches, one
for each branch.
Sorry I forgot to mention this part in my previous reply.
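Here is a minimal sketch of the mask approach, assuming PyTorch and assuming the criterion is "first feature equals 0" (that is only my guess at your condition, so swap in your own). `split_by_condition` and `merge_outputs` are hypothetical helper names. The merge step is optional and only matters if you need the branch outputs back in the original batch order.

```python
import torch

def split_by_condition(data: torch.Tensor):
    # Hypothetical criterion: a sample goes to branch A if its first feature is 0.
    mask = data[:, 0] == 0        # 1D bool tensor, shape (batch_size,)
    batch_a = data[mask]          # samples that satisfy the condition
    batch_b = data[~mask]         # remaining samples (inverse mask)
    return batch_a, batch_b, mask

def merge_outputs(out_a: torch.Tensor, out_b: torch.Tensor, mask: torch.Tensor):
    # Write each branch's output back into its original row positions.
    # Assumes both branches produce outputs with the same trailing shape and dtype.
    out = out_a.new_empty((mask.shape[0],) + tuple(out_a.shape[1:]))
    out[mask] = out_a
    out[~mask] = out_b
    return out

# Usage: rows 0 and 2 go to branch A, rows 1 and 3 to branch B.
data = torch.tensor([[0., 1.], [2., 3.], [0., 5.], [4., 6.]])
a, b, mask = split_by_condition(data)
# Stand-ins for the two branches (replace with your actual sub-networks).
out = merge_outputs(a * 10, b * -1, mask)
```

Since each sub-batch is just a regular tensor, you can feed `a` and `b` to different modules and gradients will flow back through the indexing as usual.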