Given an nn.ModuleList object that contains K modules, and a pair of tensors where the first tensor holds the data to send and the second holds, for each row, the index of the module in the nn.ModuleList that should process it, how can I forward each row to its associated module efficiently?
The only way I’ve been able to do it is to loop over the modules and forward, as one batch, all the rows assigned to each module of the nn.ModuleList. Is there a faster way?
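One variation of the per-module batching described above is to sort the rows by module index so that each module's samples form one contiguous slice, run each module once on its slice, and then undo the sort. This is a minimal sketch assuming every module accepts a 2-D (rows, features) batch; `route_sorted` is a hypothetical helper name, not a PyTorch API. It still makes one call per module, so any speed-up comes only from replacing K full-length comparisons and masked copies with a single sort and split, and should be measured.

```python
import torch
import torch.nn as nn


def route_sorted(module_list, data, t_idx):
    """Send row i of `data` through module_list[t_idx[i]] (hypothetical helper).

    Rows are grouped by module index with one argsort, split into one
    contiguous chunk per module, processed, and restored to the input order.
    """
    order = torch.argsort(t_idx)  # permutation that groups rows by module index
    # number of rows assigned to each module, in module-index order
    counts = torch.bincount(t_idx, minlength=len(module_list)).tolist()
    chunks = torch.split(data[order], counts)  # chunk k holds the rows for module k
    # skip modules that received no rows, so no module is called on an empty batch
    outputs = [m(chunk) for m, chunk in zip(module_list, chunks) if chunk.shape[0] > 0]
    sorted_out = torch.cat(outputs)
    # argsort of a permutation is its inverse: maps rows back to the original order
    return sorted_out[torch.argsort(order)]


if __name__ == "__main__":
    module_list = nn.ModuleList([nn.Linear(40, 40) for _ in range(10)])
    t_idx = torch.randint(0, 10, (1024,))
    data = torch.rand(1024, 40)
    print(route_sorted(module_list, data, t_idx).shape)  # torch.Size([1024, 40])
```

Because the inverse permutation is applied by indexing rather than by in-place assignment, gradients flow back to every module's parameters as usual.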
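import torch
import torch.nn as nn

input_size = 40
output_size = 40
batch_size = 1024
number_of_module = 10
device = torch.device('cpu')

module_list = nn.ModuleList([nn.Linear(input_size, output_size) for _ in range(number_of_module)]).to(device)
# index of the module in module_list that each sample should be sent to
t_idx = torch.randint(0, number_of_module, (batch_size,), device=device)
data = torch.rand(batch_size, input_size, device=device)
# output buffer; every row is overwritten below, so its initial values do not matter
next_state = torch.empty(batch_size, output_size, device=device)
# one forward call per module that actually appears in t_idx
for ixx in t_idx.unique().tolist():
    mask = t_idx == ixx  # boolean mask selecting the samples routed to module ixx
    next_state[mask] = module_list[ixx](data[mask])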
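If, as in the example above, every module is an nn.Linear with the same input and output sizes, the Python loop can be removed entirely: stack the K weight matrices and biases, gather the ones each sample needs, and apply them all with one batched matrix multiply. This is a sketch that holds only under that assumption (it does not apply to arbitrary modules); `batched_linear_forward` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def batched_linear_forward(module_list, data, t_idx):
    """Vectorized forward for a ModuleList of identically shaped nn.Linear layers.

    Hypothetical helper: computes y_i = W_{t_i} x_i + b_{t_i} for every row i
    at once instead of looping over the modules.
    """
    # stack the parameters of all K modules: (K, out, in) and (K, out)
    weight = torch.stack([m.weight for m in module_list])
    bias = torch.stack([m.bias for m in module_list])
    # gather each sample's weight matrix and bias: (B, out, in) and (B, out)
    w = weight[t_idx]
    b = bias[t_idx]
    # (B, out, in) @ (B, in, 1) -> (B, out, 1), then drop the trailing dim
    return torch.bmm(w, data.unsqueeze(-1)).squeeze(-1) + b


if __name__ == "__main__":
    module_list = nn.ModuleList([nn.Linear(40, 40) for _ in range(10)])
    t_idx = torch.randint(0, 10, (1024,))
    data = torch.rand(1024, 40)
    print(batched_linear_forward(module_list, data, t_idx).shape)  # torch.Size([1024, 40])
```

The trade-off is memory: `weight[t_idx]` materializes a (batch_size, output_size, input_size) tensor, about 1.6 million floats for the sizes in the question, in exchange for a single `torch.bmm` call instead of up to K separate forward passes. Whether that is faster depends on the device and the layer sizes, so it is worth timing against the loop.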