I have a network with a stack of 6 decoder layers, followed by 2 heads, each of which is itself a decoder. Some samples in the batch should go through the 6 shared layers and then the first head, and the others should go through the same 6 layers and then the second head. For a given batch, I have a tensor of size batch_size containing 1s and 0s that tells me which route each sample should take. What is the best way to handle this scenario and train the network?
Should I pass all the data through both heads and then select, per sample, the output I need, or is there a better way that avoids feeding the data twice? Can mask variables help here?
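To make the two options concrete, here is a minimal sketch of what I have in mind. It assumes PyTorch, which may not match your framework. The names `TwoHeadModel`, `forward_both`, and `forward_routed` are made up for illustration. Plain linear layers stand in for the real decoder layers and heads, and both heads are assumed to return the same output shape as their input.

```python
import torch
import torch.nn as nn


class TwoHeadModel(nn.Module):
    """Hypothetical stand-in: 6 shared layers followed by 2 heads."""

    def __init__(self, d_model=16, n_layers=6):
        super().__init__()
        # Linear + ReLU blocks stand in for the real decoder layers.
        self.trunk = nn.Sequential(
            *[nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())
              for _ in range(n_layers)]
        )
        # Assumption: both heads map d_model -> d_model.
        self.head0 = nn.Linear(d_model, d_model)
        self.head1 = nn.Linear(d_model, d_model)

    def forward_both(self, x, route):
        # Option A: run every sample through both heads, then pick per sample.
        h = self.trunk(x)                      # shared layers run once
        out0 = self.head0(h)
        out1 = self.head1(h)
        # Reshape route to (batch, 1, ..., 1) so it broadcasts over features.
        mask = route.bool().view(-1, *([1] * (h.dim() - 1)))
        return torch.where(mask, out1, out0)

    def forward_routed(self, x, route):
        # Option B: split the batch so each head only sees its own samples.
        h = self.trunk(x)                      # shared layers run once
        m1 = route.bool()
        m0 = ~m1
        out = torch.empty_like(h)
        if m0.any():
            out[m0] = self.head0(h[m0])
        if m1.any():
            out[m1] = self.head1(h[m1])
        return out


torch.manual_seed(0)
model = TwoHeadModel()
x = torch.randn(5, 16)                         # batch of 5 samples
route = torch.tensor([0, 1, 1, 0, 1])          # 0 -> first head, 1 -> second head

out_both = model.forward_both(x, route)
out_routed = model.forward_routed(x, route)

# A loss on the routed output backpropagates into the shared layers and
# into each head, but only through the samples that head processed.
loss = out_routed.pow(2).mean()
loss.backward()
```

In both variants the 6 shared layers run only once per batch, so the extra cost of the first option is limited to the heads. As far as I understand, the two should give the same outputs and gradients when the heads treat each sample independently. That would stop holding if a head contains something like batch normalization, whose statistics depend on which samples are in the batch.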
Thanks.