What is the best way to approach a problem where the data needs to go through one of two routes in a network, depending on an argument?

I have a stack of 6 decoder layers on top of each other, followed by 2 heads, where each head is itself a decoder. Some samples in the batch should go through the first 6 layers and the first head, and some samples should go through the first 6 layers and the second head. For each batch, I have a tensor of size batch_size containing 1s and 0s, which tells me which route each sample should take. What is the best way to handle this scenario and train the network?

Should I pass all the data through both heads and then slice the outputs to keep only the ones I need, or is there a better way that avoids feeding the data twice? Can mask variables help here?


Since the entire batch seems to use the first 6 layers, you could pass it to them directly and split the batch afterwards, e.g. by indexing with the routing tensor. Once you have the two chunks, you could pass each one to its corresponding head and, if needed, concatenate the outputs afterwards (restoring the original sample order).
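
A minimal sketch of this idea, assuming PyTorch (the thread does not name the framework). The class name `TwoHeadNet`, the sizes, and the plain `nn.Linear` blocks standing in for the real decoder layers and decoder heads are all hypothetical placeholders. It also assumes both heads produce outputs of the same shape, so the results can be written back into one tensor in the original order:

```python
import torch
import torch.nn as nn


class TwoHeadNet(nn.Module):
    """Shared trunk followed by two heads; each sample uses exactly one head."""

    def __init__(self, dim=16, out_dim=16, num_shared=6):
        super().__init__()
        # Placeholder for the 6 stacked decoder layers (assumption: simple MLP blocks).
        self.shared = nn.Sequential(
            *[nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(num_shared)]
        )
        # Placeholders for the two decoder heads.
        self.head0 = nn.Linear(dim, out_dim)
        self.head1 = nn.Linear(dim, out_dim)
        self.out_dim = out_dim

    def forward(self, x, route):
        # Every sample goes through the shared layers exactly once.
        h = self.shared(x)

        # Boolean masks over the batch dimension: route == 0 -> head0, route == 1 -> head1.
        mask1 = route.bool()
        mask0 = ~mask1

        # Run each head only on its own samples and write the results back
        # at the original batch positions, so the sample order is preserved.
        out = h.new_empty((*h.shape[:-1], self.out_dim))
        out[mask0] = self.head0(h[mask0])
        out[mask1] = self.head1(h[mask1])
        return out


model = TwoHeadNet()
x = torch.randn(4, 16)              # batch of 4 samples
route = torch.tensor([0, 1, 1, 0])  # per-sample routing flags
out = model(x, route)               # shape: (4, 16)
```

In this sketch, gradients only flow into a head from the samples routed to it, while the shared layers receive gradients from the whole batch. If the two heads returned different shapes, the outputs would have to be kept as two separate tensors (together with their masks) instead of being merged.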