Suppose I have the module below. If the first element of the output of fc_type is greater than the second, I want the forward pass to go through fc_1; otherwise, I want it to go through fc_2. My current implementation needs a for loop over every input in the batch. Does this run sequentially, or does it get parallelized automatically somehow? If it is sequential, is there a better way to parallelize it? Thanks!
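import torch
import torch.nn as nn


class mm(nn.Module):
    def __init__(self):
        super(mm, self).__init__()
        self.fc_type = nn.Linear(4, 2)  # two routing scores per input
        self.fc_1 = nn.Linear(4, 1)     # head used when score 0 > score 1
        self.fc_2 = nn.Linear(4, 1)     # head used otherwise

    def forward(self, input):
        out = self.fc_type(input)
        # Output buffer with the same dtype/device as the input
        y = torch.zeros(out.shape[0], 1).to(input)
        # Visit the batch one row at a time and pick a head for each row
        for i in range(out.shape[0]):
            if out[i][0] > out[i][1]:
                y[i] = self.fc_1(input[i:i+1])
            else:
                y[i] = self.fc_2(input[i:i+1])
        return y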
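A minimal sketch of one common batched alternative, assuming it is acceptable to evaluate both heads on every row: run fc_1 and fc_2 on the whole batch, build a boolean mask from the output of fc_type, and pick the result per row with torch.where. The class name MMVectorized is hypothetical, and the check at the end compares it against a per-row loop written like the one above, using the same weights.

```python
import torch
import torch.nn as nn


class MMVectorized(nn.Module):
    """Hypothetical batched variant of mm: no Python loop over the batch."""

    def __init__(self):
        super().__init__()
        self.fc_type = nn.Linear(4, 2)
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, input):
        out = self.fc_type(input)            # (N, 2) routing scores
        mask = out[:, 0:1] > out[:, 1:2]     # (N, 1) bool, True -> use fc_1
        # Run both heads on the whole batch at once...
        y1 = self.fc_1(input)                # (N, 1)
        y2 = self.fc_2(input)                # (N, 1)
        # ...then select fc_1's result where mask is True, fc_2's elsewhere
        return torch.where(mask, y1, y2)     # (N, 1)


torch.manual_seed(0)
model = MMVectorized()
x = torch.randn(8, 4)
y = model(x)

# Reference: a per-row loop with the same weights, as in the question
ref = torch.stack([
    model.fc_1(x[i])
    if model.fc_type(x[i])[0] > model.fc_type(x[i])[1]
    else model.fc_2(x[i])
    for i in range(x.shape[0])
])
print(torch.allclose(y, ref))
```

The trade-off is that both heads run on every row, so the head computation roughly doubles, but it happens as two batched matrix multiplies instead of many tiny per-row calls, which is typically much faster, especially on a GPU. If the heads are expensive, an alternative under the same assumptions is to index with the mask (for example input[mask.squeeze(1)]) and run each head only on its own subset of rows. In both versions fc_type receives no gradient, since its output is only used in a comparison.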