Suppose I have the module below. If the first node of the output of fc_type is greater than the second node, I want to forward pass through fc_1; otherwise I want to forward pass through fc_2. Currently, my implementation requires a for loop over each input in the batch. Does this loop run sequentially, or does it get automatically parallelized somehow? If not, is there a better way of parallelizing it? Thanks!
import torch
import torch.nn as nn

class mm(nn.Module):
    def __init__(self):
        super(mm, self).__init__()
        self.fc_type = nn.Linear(4, 2)  # picks the branch for each sample
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, input):
        out = self.fc_type(input)  # shape: (batch, 2)
        y = torch.zeros(out.shape[0], 1).to(input)
        # Loop over the batch, choosing fc_1 or fc_2 per sample
        for i in range(out.shape[0]):
            if out[i, 0] > out[i, 1]:
                y[i] = self.fc_1(input[i:i+1])
            else:
                y[i] = self.fc_2(input[i:i+1])
        return y
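For comparison, here is a minimal sketch of one way the per-sample loop could be avoided. A Python for loop like the one above runs its iterations one after another, so this version instead runs both fc_1 and fc_2 on the whole batch and picks the result row by row with torch.where. The class name MMVectorized and the variable use_fc_1 are made up for illustration. The sketch assumes that computing both branches for every sample is acceptable, which does extra work but keeps everything as batched tensor operations.

```python
import torch
import torch.nn as nn


class MMVectorized(nn.Module):  # hypothetical name, not from the original post
    def __init__(self):
        super().__init__()
        self.fc_type = nn.Linear(4, 2)
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, input):
        out = self.fc_type(input)  # shape: (batch, 2)
        # Boolean mask of shape (batch, 1): True where the first node wins
        use_fc_1 = out[:, 0:1] > out[:, 1:2]
        # Both branches are evaluated for the full batch; torch.where then
        # selects fc_1's output where the mask is True and fc_2's elsewhere
        return torch.where(use_fc_1, self.fc_1(input), self.fc_2(input))


if __name__ == "__main__":
    model = MMVectorized()
    x = torch.randn(8, 4)
    print(model(x).shape)
```

Because torch.where only routes gradients through the selected element, each sample should still only update the branch it was assigned to (plus nothing through fc_type, since the comparison itself is not differentiable, same as in the loop version).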