How do I do batch training when I have multiple neural networks and must select, per input, which network that input passes through?

class NN(torch.nn.Module):
    """Small MLP mapping one feature to one output.

    Architecture: Linear(1, H) -> ReLU -> Linear(H, 1) -> ReLU.
    """

    def __init__(self, H):
        """
        :param H: Width of the hidden layer.
        """
        # Must be the real ``__init__`` (the pasted ``init`` was never called),
        # and must call the base constructor before assigning sub-modules.
        super().__init__()
        self.nn_1 = nn.Linear(1, H)
        self.nn_2 = nn.Linear(H, 1)

    def forward(self, x):
        """
        :param x: Tensor whose last dimension has size 1, i.e. shape (..., 1).
        :return: Tensor of shape (..., 1).
        """
        l1 = F.relu(self.nn_1(x))
        # NOTE(review): the ReLU on the output layer clamps predictions to >= 0;
        # drop it if negative targets are ever possible.
        l2 = F.relu(self.nn_2(l1))
        return l2

class TwoLayerNet(torch.nn.Module):
    """Route each input element to one of six ``NN`` sub-networks.

    For every sample ``b`` in the batch, each element ``x[b, i]`` is passed
    through the sub-network selected by ``n[b, i]`` (1-based: 1 -> ``nn1`` ...
    6 -> ``nn6``). The selected outputs are then summed into one prediction
    per sample.
    """

    def __init__(self):
        super().__init__()
        # Individual attributes (rather than a ModuleList) keep the
        # parameter names nn1.*, ..., nn6.* unchanged.
        self.nn1 = NN(4)
        self.nn2 = NN(4)
        self.nn3 = NN(4)
        self.nn4 = NN(4)
        self.nn5 = NN(4)
        self.nn6 = NN(4)

    def forward(self, x, n):
        """
        :param x: Input of shape (batch, num_elements, 1).
        :param n: Selector of shape (batch, num_elements) holding values 1..6.
            Any element whose selector is not exactly 1..6 selects no network
            and contributes 0 to the sum.
        :return: Tensor of shape (batch, 1), one prediction per sample.
            An unbatched input of shape (num_elements, 1) with selector shape
            (num_elements,) yields shape (1,).
        """
        nets = (self.nn1, self.nn2, self.nn3, self.nn4, self.nn5, self.nn6)
        # Run every sub-network on the whole batch at once. The number of
        # forward passes grows with the number of networks, not with the data.
        outputs = torch.stack([net(x) for net in nets], dim=-2)  # (B, L, 6, 1)
        # Compare against 1..6 in n's own dtype so matching is exact, as in
        # the original `n[i] == k` checks.
        ids = torch.arange(1, len(nets) + 1, dtype=n.dtype, device=n.device)
        mask = (n.unsqueeze(-1) == ids).to(outputs)  # (B, L, 6)
        # Keep only the selected network's output for each element...
        selected = (outputs * mask.unsqueeze(-1)).sum(dim=-2)  # (B, L, 1)
        # ...and sum over the elements of each sample.
        return selected.sum(dim=-2)  # (B, 1)

model = TwoLayerNet()

# `size_average=False` is deprecated; `reduction="sum"` is the equivalent.
criterion = nn.MSELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# values: (batch=2, elements=4, features=1); y: one target per sample.
values = torch.tensor([[[1.0], [2.2], [3.1], [4.1]],
                       [[1.5], [2.4], [3.1], [4.5]]])
y = torch.tensor([[10.0], [24.0]])
# Which sub-network (1..6) handles each element of `values`.
nnselector = torch.tensor([[1, 1, 1, 1], [2, 3, 4, 1]])

for _ in range(10):
    # The whole loop body must be indented; otherwise only the first
    # statement repeats and the optimizer steps once.
    y_pred = model(values, nnselector)
    print(y_pred)
    loss = criterion(y_pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

@Yashaswi_Pathak Did you find a solution for this network?

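One approach is to pass the data through all of the networks and use `n` to select which network's output to keep for each sample. This way, the number of forward passes grows with the number of networks instead of with the amount of data. The trade-off is that the outputs of unselected networks are computed and then discarded.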

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectNet(nn.Module):
    """
    Forward pass by selecting from multiple networks.

    Every sub-network is run on the whole batch, and ``n`` picks, per sample,
    which network's output is kept. The number of forward passes therefore
    grows with the number of networks, not with the batch size. Outputs of
    unselected networks are computed and then zeroed out, so they receive
    zero gradient.
    """

    def __init__(self, num_nets: int, input_dim: int, output_dim: int):
        """
        Constructor.

        :param num_nets: Number of networks.
        :param input_dim: Input dimension.
        :param output_dim: Output dimension.
        """
        # Must run before any sub-module is assigned; otherwise nn.Module
        # raises "cannot assign module before Module.__init__() call".
        super().__init__()
        self.num_nets = num_nets
        self.nets = nn.ModuleList([
            nn.Linear(input_dim, output_dim) for _ in range(num_nets)
        ])

    def forward(self, x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        :param x: Input tensor of shape (batch_size, input_dim).
        :param n: int64 index tensor of shape (batch_size,) with values in
            [0, num_nets). ``F.one_hot`` raises for indices outside that range.
        :return: Output tensor of shape (batch_size, output_dim), whose row
            ``b`` equals ``self.nets[n[b]](x[b])``.
        """
        outputs = torch.stack([net(x) for net in self.nets], dim=1)  # (batch_size, num_nets, output_dim)
        # one_hot yields int64; cast to the outputs' dtype/device so the
        # masked product stays in floating point on the same device.
        selector = F.one_hot(n, self.num_nets).to(outputs)  # (batch_size, num_nets)
        selector = selector.unsqueeze(-1)  # (batch_size, num_nets, 1)
        return (outputs * selector).sum(dim=1)  # (batch_size, output_dim)

num_nets = 3
input_dim = 4
output_dim = 1
batch_size = 3

net = SelectNet(num_nets, input_dim, output_dim)
x = torch.randn(batch_size, input_dim)
# Cycle through the available nets so every index is valid even when
# batch_size > num_nets (a plain arange would make one_hot fail then).
n = torch.arange(batch_size) % num_nets
print(net(x, n))