Vectorizing layer usage in PyTorch

I have a batch of sentences of shape (batch_size, seq_length, emb_size), together with a domain id for each sentence. My model has a separate linear layer for every domain_id, and I want to route each sentence through the linear layer matching its domain_id.
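For concreteness, here is a minimal setup of the kind I mean (the sizes are made up, and I'm assuming one domain_id per sentence, which is how my loop below treats it):

```python
import torch
import torch.nn as nn

batch_size, seq_length, emb_size, num_domains = 4, 10, 16, 3

batch = torch.randn(batch_size, seq_length, emb_size)
# one domain id per sentence (assumption: ids are 0..num_domains-1)
domain_ids = torch.randint(0, num_domains, (batch_size,))

# one linear layer per domain
domain_layers = nn.ModuleList(
    nn.Linear(emb_size, emb_size) for _ in range(num_domains)
)
```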

For now, my code uses a for loop to do so:

output = []
for domain_id, sentence in zip(domain_ids, batch):
    output.append(self.domain_layers[domain_id](sentence))

output = torch.stack(output, dim=0)

But this seems to slow down training a lot. Is there a way to vectorize this part? I tried putting all the domain layers into an nn.ModuleList, but nn.ModuleList only accepts a single integer as an index, so I can't index it with a whole tensor of domain ids at once.
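One idea I've been considering, sketched under the assumptions above (per-sentence domain ids, all layers with identical in/out sizes): stack every domain's weight and bias into single tensors, gather the per-sentence weights by fancy indexing, and apply them all at once with torch.bmm. Not sure if this is the recommended approach:

```python
import torch

batch_size, seq_length, emb_size, num_domains = 4, 10, 16, 3

batch = torch.randn(batch_size, seq_length, emb_size)
domain_ids = torch.randint(0, num_domains, (batch_size,))

# All domain weights/biases stacked into one tensor each
# (in a real model these could be nn.Parameters, or built once
# from the per-domain nn.Linear layers)
weights = torch.randn(num_domains, emb_size, emb_size)  # (D, E, E)
biases = torch.randn(num_domains, emb_size)             # (D, E)

# Gather the weight/bias for each sentence via tensor indexing...
w = weights[domain_ids]   # (batch_size, emb_size, emb_size)
b = biases[domain_ids]    # (batch_size, emb_size)

# ...then apply every sentence's own layer in one batched matmul
output = torch.bmm(batch, w) + b.unsqueeze(1)  # (batch_size, seq_length, emb_size)
```

This replaces the Python-level loop with a single bmm, at the cost of materializing one weight matrix per sentence in the batch.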

Any help is greatly appreciated. Thanks!