Batch processing when samples should go through different layers at the end of the network

I have data where not all samples go through the same network. To be exact, all samples pass through the same layers up to a certain point in the network, and after that they diverge based on an ID attached to each sample. For instance, I have 6 decoder layers in a transformer. All samples go through the first 4 decoders, but the 6th decoder is specific to each sample ID, so each sample passes through a separate 6th decoder depending on its ID. I am wondering what the most efficient way to handle this is. One obvious option is to build batches in which every sample has the same ID, but that interferes with random sampling. I appreciate any help.
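
To make the setup concrete, here is a minimal sketch of one possible approach, assuming PyTorch: keep the batch randomly mixed, run the shared layers on the whole batch, then route each group of samples through its ID-specific layer with a boolean mask and write the results back in the original order. The names `RoutedDecoder` and `make_block` are hypothetical, and the small MLP blocks are only stand-ins for real decoder layers. The sketch also assumes a single ID-specific layer directly after the shared stack, which may not match the exact layer layout described above.

```python
import torch
import torch.nn as nn


def make_block(d_model):
    # Hypothetical stand-in for a real decoder layer; swap in your own module.
    return nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())


class RoutedDecoder(nn.Module):
    def __init__(self, d_model=8, num_shared=4, num_ids=3):
        super().__init__()
        # Layers that every sample passes through.
        self.shared = nn.ModuleList([make_block(d_model) for _ in range(num_shared)])
        # One final layer per sample ID (IDs assumed to be 0 .. num_ids - 1).
        self.per_id = nn.ModuleList([make_block(d_model) for _ in range(num_ids)])

    def forward(self, x, ids):
        # x: (batch, seq_len, d_model), ids: (batch,) integer tensor
        for layer in self.shared:
            x = layer(x)
        out = torch.empty_like(x)
        # Loop over the distinct IDs present in this batch, not over samples.
        for i in ids.unique().tolist():
            mask = ids == i
            out[mask] = self.per_id[i](x[mask])
        return out


if __name__ == "__main__":
    torch.manual_seed(0)
    model = RoutedDecoder()
    x = torch.randn(5, 7, 8)              # a randomly mixed batch
    ids = torch.tensor([2, 0, 2, 1, 0])   # one ID per sample
    print(model(x, ids).shape)
```

With this layout the data loader can keep sampling at random, since the grouping happens inside `forward`. The cost is one forward call per distinct ID in the batch, so it is likely to work best when the number of IDs is small relative to the batch size.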