How to process the samples of a batch with different networks

I have a batch of data:

[[0.1, 0.3],
 [0.5, 0.2],
 [0.7, 0.8],
 [0.6, 0.6],
 [0.2, 0.9]]

I have two networks, mlp1() and mlp2().

How can I process the rows [0.7, 0.8] and [0.2, 0.9] with mlp1() and the rest with mlp2(), to get:

[mlp2([0.1, 0.3]),
 mlp2([0.5, 0.2]),
 mlp1([0.7, 0.8]),
 mlp2([0.6, 0.6]),
 mlp1([0.2, 0.9])]

Any suggestions are appreciated!
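For reference, here is the desired result written as a plain per-row loop. It is a minimal sketch that assumes hypothetical stand-ins for mlp1()/mlp2() (small nn.Sequential MLPs mapping 2 to 3 features) and a hypothetical boolean list use_mlp1 marking which rows go to mlp1(). It is slow for large batches but useful as a correctness check for any vectorized version.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for mlp1() and mlp2(); any modules mapping 2 -> k features work.
mlp1 = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 3))
mlp2 = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 3))

x = torch.tensor([[0.1, 0.3],
                  [0.5, 0.2],
                  [0.7, 0.8],
                  [0.6, 0.6],
                  [0.2, 0.9]])

# Hypothetical per-row assignment: True -> mlp1, False -> mlp2 (rows 2 and 4 go to mlp1).
use_mlp1 = torch.tensor([False, False, True, False, True])

# Reference (slow) version: run each row through its network one at a time
# and stack the results back in the original order.
rows = []
for row, flag in zip(x, use_mlp1):
    net = mlp1 if flag else mlp2
    rows.append(net(row.unsqueeze(0)).squeeze(0))  # add, then remove, a batch dim
ref = torch.stack(rows)
print(ref.shape)  # torch.Size([5, 3])
```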

I'm not sure if I understand the use case correctly, but you should be able to slice the input and pass each part to the corresponding model.
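A minimal sketch of the slicing idea, assuming a hypothetical boolean mask that marks the rows for mlp1() and using nn.Linear(2, 3) layers as placeholders for both networks. Each slice goes through its network, and boolean-mask assignment writes the outputs back into their original rows, so the merge happens in the same step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical networks: both map 2 input features to 3 output features.
mlp1 = nn.Linear(2, 3)
mlp2 = nn.Linear(2, 3)

x = torch.tensor([[0.1, 0.3],
                  [0.5, 0.2],
                  [0.7, 0.8],
                  [0.6, 0.6],
                  [0.2, 0.9]])

# Hypothetical routing mask: True rows go to mlp1, the rest to mlp2.
mask = torch.tensor([False, False, True, False, True])

out1 = mlp1(x[mask])   # shape [2, 3], original rows 2 and 4
out2 = mlp2(x[~mask])  # shape [3, 3], original rows 0, 1 and 3

# Write each part back into its original rows of a preallocated output
# whose width matches the network output, not the input.
out = out1.new_empty((x.size(0), out1.size(1)))
out[mask] = out1
out[~mask] = out2
```

The in-place assignments are tracked by autograd, so gradients still flow back to both networks through out.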

Hi ptrblck, exactly, I will have to slice the input and feed each slice to its network. However, slicing gives me two parts, one for mlp1() and one for mlp2(). The tricky part is that I need to merge the two outputs back into a single tensor, in the original row order, after the forward passes. Any hints?

Assuming you have precomputed the indices used for the split, you could restore the original order by indexing into the result tensor:

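import torch

x = torch.tensor([[0.1, 0.3],
                  [0.5, 0.2],
                  [0.7, 0.8],
                  [0.6, 0.6],
                  [0.2, 0.9]])

# precomputed row indices for each network
mlp1_idx = torch.tensor([2, 4])
mlp2_idx = torch.tensor([0, 1, 3])

# split the batch
x1 = x[mlp1_idx]
x2 = x[mlp2_idx]

# process, e.g. x1 = mlp1(x1) and x2 = mlp2(x2)
...

# write every row back to its original position
# (if the networks change the feature size, allocate res with the output width instead of zeros_like(x))
res = torch.zeros_like(x)
res[torch.cat((mlp1_idx, mlp2_idx), dim=0)] = torch.cat((x1, x2), dim=0)
print(res == x)  # all True here, since x1 and x2 were not modified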
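An equivalent way to restore the order is to gather instead of assign: concatenate the outputs and index them with the inverse permutation, obtained via argsort of the concatenated indices. This is a sketch assuming placeholder nn.Linear(2, 4) networks (so the output width differs from the input) and that mlp1_idx and mlp2_idx together cover every row exactly once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical networks with a different output width than the input (2 -> 4),
# so the result is built from the outputs rather than from zeros_like(x).
mlp1 = nn.Linear(2, 4)
mlp2 = nn.Linear(2, 4)

x = torch.tensor([[0.1, 0.3],
                  [0.5, 0.2],
                  [0.7, 0.8],
                  [0.6, 0.6],
                  [0.2, 0.9]])
mlp1_idx = torch.tensor([2, 4])
mlp2_idx = torch.tensor([0, 1, 3])

out1 = mlp1(x[mlp1_idx])
out2 = mlp2(x[mlp2_idx])

# perm[k] is the original row of the k-th concatenated output row;
# argsort(perm) is the inverse permutation that puts every row back in place.
perm = torch.cat((mlp1_idx, mlp2_idx))
res = torch.cat((out1, out2))[perm.argsort()]
```

Since no output tensor has to be preallocated, this variant adapts automatically to whatever width the networks produce.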

Depending on your workflow, you might also be able to use e.g. scatter.
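A minimal sketch of the scatter route, reusing the index tensors from the snippet above and, as there, using the unmodified slices as stand-ins for the network outputs so the round trip should reproduce x. Tensor.scatter_ along dim 0 needs an index with the same shape as the source, so each row index is repeated across the columns; Tensor.index_copy_ is shown as an alternative that takes the 1-D row index directly.

```python
import torch

x = torch.tensor([[0.1, 0.3],
                  [0.5, 0.2],
                  [0.7, 0.8],
                  [0.6, 0.6],
                  [0.2, 0.9]])
mlp1_idx = torch.tensor([2, 4])
mlp2_idx = torch.tensor([0, 1, 3])

# Stand-ins for the network outputs (just the sliced inputs here).
out1 = x[mlp1_idx]
out2 = x[mlp2_idx]

src = torch.cat((out1, out2))          # rows in "split" order
idx = torch.cat((mlp1_idx, mlp2_idx))  # target row for each src row

# scatter_: res[idx[i], j] = src[i, j], with the row index expanded over all columns.
res = src.new_zeros(src.shape)
res.scatter_(0, idx.unsqueeze(1).expand_as(src), src)

# index_copy_: res2[idx[i]] = src[i], same placement with a 1-D index.
res2 = src.new_zeros(src.shape).index_copy_(0, idx, src)
print(torch.equal(res, x), torch.equal(res2, x))  # True True
```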
