I have a batch of sentences with shape (batch_size, seq_length, emb_size), together with domain ids of shape (batch_size, seq_length, domain_id). My model has a separate linear layer for each domain_id, and I want to select which linear layer to apply according to the domain_id.
For now, my code uses a for loop to do so:
output = []
for domain_id, sentence in zip(domain_ids, batch):
    # apply the linear layer that belongs to this sentence's domain
    output.append(self.domain_layers[domain_id](sentence))
output = torch.stack(output, dim=0)
But this seems to slow training down a lot. Is there a way to vectorize this part? I tried putting all the domain_layers into an nn.ModuleList, but an nn.ModuleList can only be indexed with a single integer at a time, not with a whole tensor of domain ids.
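One direction that might work, as a minimal sketch only: store the per-domain weights in a single stacked parameter, gather them with the id tensor, and apply them with torch.bmm. This assumes that every domain layer has the same input and output size and that each sentence has exactly one domain id, so domain_ids has shape (batch_size,). The class name DomainLinear, its constructor arguments, and the simple random initialization are hypothetical choices for illustration and do not match the default nn.Linear initialization.

```python
import torch
import torch.nn as nn


class DomainLinear(nn.Module):
    """Hypothetical helper: one linear layer per domain, stored as stacked tensors."""

    def __init__(self, num_domains: int, in_features: int, out_features: int):
        super().__init__()
        # weight[d] plays the role of domain_layers[d].weight, stored as (in, out)
        # so it can be used directly on the right-hand side of a matmul.
        # Assumption: a simple scaled normal init, not nn.Linear's default.
        self.weight = nn.Parameter(
            torch.randn(num_domains, in_features, out_features) * in_features ** -0.5
        )
        self.bias = nn.Parameter(torch.zeros(num_domains, out_features))

    def forward(self, x: torch.Tensor, domain_ids: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_length, in_features)
        # domain_ids: (batch_size,) integer tensor, one id per sentence
        w = self.weight[domain_ids]  # (batch_size, in_features, out_features)
        b = self.bias[domain_ids]    # (batch_size, out_features)
        # Batched matmul applies each sentence's own weight matrix in one call.
        return torch.bmm(x, w) + b.unsqueeze(1)


# Small usage example with made-up sizes.
layer = DomainLinear(num_domains=3, in_features=8, out_features=4)
batch = torch.randn(5, 7, 8)
domain_ids = torch.tensor([0, 2, 1, 2, 0])
output = layer(batch, domain_ids)  # shape: (5, 7, 4)
```

The trade-off is that indexing self.weight with domain_ids materializes one copy of a weight matrix per sentence, so this is probably only attractive when the layers are small relative to the batch.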
Any help is greatly appreciated. Thanks!