Forward pass through different branches

Suppose I have the module below. For each sample, if the first output of fc_type is greater than the second, I want to forward pass through fc_1; otherwise I want to forward pass through fc_2. Currently, my implementation requires a for loop over each input in the batch. Is this sequential, or does it get automatically parallelized somehow? If not, is there a better way of parallelizing it? Thanks!

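import torch
import torch.nn as nn


class mm(nn.Module):
    def __init__(self):
        super(mm, self).__init__()
        self.fc_type = nn.Linear(4, 2)  # scores that pick the branch
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, input):
        out = self.fc_type(input)
        # Output buffer with the same dtype and device as the input
        y = torch.zeros(out.shape[0], 1).to(input)
        # One iteration (and one small Linear call) per sample in the batch
        for i in range(out.shape[0]):
            if out[i][0] > out[i][1]:
                y[i] = self.fc_1(input[i:i+1])
            else:
                y[i] = self.fc_2(input[i:i+1])
        return y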

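The Python loop runs sequentially, one sample at a time. Instead of iterating over each row of out, you could use out to split the batch into two groups of inputs up front, call fc_1 and fc_2 once on each group, and create the output tensor from the two results.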
Depending on the batch size, the performance difference might be insignificant.
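
A minimal sketch of that suggestion, assuming a boolean mask is an acceptable way to split the batch. The class name MaskedBranchModule and the variable mask are made up for illustration and are not from the original post; the layer sizes are copied from the question.

```python
import torch
import torch.nn as nn


class MaskedBranchModule(nn.Module):  # hypothetical name, same layers as the question
    def __init__(self):
        super().__init__()
        self.fc_type = nn.Linear(4, 2)
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, input):
        out = self.fc_type(input)
        # True where the first score is greater than the second, one entry per sample
        mask = out[:, 0] > out[:, 1]
        y = torch.zeros(out.shape[0], 1).to(input)
        # Each branch is called once, on its own subset of rows
        y[mask] = self.fc_1(input[mask])
        y[~mask] = self.fc_2(input[~mask])
        return y


model = MaskedBranchModule()
x = torch.randn(8, 4)
print(model(x).shape)  # torch.Size([8, 1])
```

Each branch now sees one batched call instead of one call per sample. If either subset happens to be empty, the indexed tensors simply have zero rows and the assignment is a no-op. Another option, if computing both branches for every sample is acceptable, would be to run fc_1 and fc_2 on the full batch and pick per row with torch.where(mask.unsqueeze(1), ..., ...).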
