Calculation in forward based on the value of inputs (Xs)

Hi there,
How can I perform the following conditional computation in vectorized form, i.e. without looping over the rows of X in Python?

import torch
import torch.nn as nn

class ExampleMultinet(nn.Module):
    def __init__(self):
        super().__init__()
        self.n_features = 10 - 1  # first column of the input is used for conditioning
        # Three sub-networks with identical shapes; which one runs depends on x[0]
        self.net1 = nn.Sequential(nn.Linear(self.n_features, 5), nn.ReLU(), nn.Linear(5, 3))
        self.net2 = nn.Sequential(nn.Linear(self.n_features, 5), nn.ReLU(), nn.Linear(5, 3))
        self.net3 = nn.Sequential(nn.Linear(self.n_features, 5), nn.ReLU(), nn.Linear(5, 3))

    def forward(self, X):
        # X has shape (batch, 10)
        out = []
        for x in X:  # Python loop over rows: this is the part to vectorize
            c = x[0].item()
            feats = x[1:].reshape(-1, self.n_features)  # shape (1, 9)
            if c < 8:
                out.append(self.net1(feats))
            elif 8 <= c < 20:
                out.append(self.net2(feats))
            elif c >= 20:
                out.append(self.net3(feats))

        return torch.cat(out, dim=0)  # shape (batch, 3)

model=ExampleMultinet()
x=torch.randn(4,10)*20
y=model(x)
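One possible vectorized alternative, sketched under the assumption that the three sub-networks share input and output sizes as in the question: build one boolean mask per condition and run each sub-network once on its subset of rows, writing the results back in the original row order. The class name VectorizedMultinet and the use of nn.ModuleList are illustrative choices, not part of the original post.

```python
import torch
import torch.nn as nn


class VectorizedMultinet(nn.Module):
    """Hypothetical vectorized variant of ExampleMultinet (illustrative name)."""

    def __init__(self, n_features: int = 9, n_hidden: int = 5, n_out: int = 3):
        super().__init__()
        self.n_out = n_out
        # Same three sub-network shapes as net1, net2, net3 in the question
        self.nets = nn.ModuleList(
            [nn.Sequential(nn.Linear(n_features, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_out))
             for _ in range(3)]
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        cond, feats = X[:, 0], X[:, 1:]  # shapes (batch,) and (batch, n_features)
        # One boolean mask per branch, same thresholds as the loop: < 8, [8, 20), >= 20
        masks = (cond < 8, (cond >= 8) & (cond < 20), cond >= 20)
        out = feats.new_zeros(X.shape[0], self.n_out)
        for net, mask in zip(self.nets, masks):
            # Each sub-network runs once on its whole (possibly empty) subset;
            # boolean-mask assignment puts the rows back in their original positions
            out[mask] = net(feats[mask])
        return out


model = VectorizedMultinet()
x = torch.randn(4, 10) * 20
y = model(x)  # shape (4, 3), same row order as the loop version
```

The Python loop now runs over the three sub-networks instead of over the batch rows, and gradients flow through the masked assignment. One behavioral difference: a row whose condition value is NaN matches none of the masks and comes out as zeros here, whereas the loop version silently drops it.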

Hi @amirmob2000,

Have a look at functorch.vmap (see the functorch.vmap page of the functorch 1.13 documentation).

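It vectorizes a per-sample function over the batch dimension, so it should work in your case, provided the branches are written with tensor operations such as torch.where instead of Python if statements on x[0].item(): vmap does not support data-dependent control flow or .item() inside the vectorized function.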
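A rough sketch of the vmap route, assuming PyTorch 2.x, where vmap is exposed as torch.func.vmap (the import falls back to functorch.vmap for 1.13-era installs). Because Python branching on tensor values is not allowed under vmap, the per-sample function evaluates all three sub-networks and picks one result with torch.where. The names nets and per_sample are hypothetical.

```python
import torch
import torch.nn as nn

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:  # older installs with the separate functorch package
    from functorch import vmap

torch.manual_seed(0)
# Three sub-networks with the same shapes as in the question
nets = nn.ModuleList(
    [nn.Sequential(nn.Linear(9, 5), nn.ReLU(), nn.Linear(5, 3)) for _ in range(3)]
)


def per_sample(x: torch.Tensor) -> torch.Tensor:
    # Inside vmap, x is a single row of shape (10,); no .item() or Python `if` on its values
    c, feats = x[0], x[1:]
    o1, o2, o3 = [net(feats) for net in nets]
    # Tensor-level selection replaces the if/elif chain: < 8 -> net1, [8, 20) -> net2, >= 20 -> net3
    return torch.where(c < 8, o1, torch.where(c < 20, o2, o3))


batched = vmap(per_sample)  # maps per_sample over dim 0 of its input
X = torch.randn(4, 10) * 20
Y = batched(X)  # shape (4, 3)
```

The trade-off is that every sample passes through all three sub-networks, roughly three times the compute of routing rows by mask. Since nn.Linear already accepts batched input, the same torch.where selection also works on the whole batch without vmap, as long as the condition is unsqueezed to shape (batch, 1) so it broadcasts against the (batch, 3) outputs.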