Hi there,
How is it possible to perform the following conditional calculations in a vectorized form?
import torch
import torch.nn as nn
class ExampleMultinet(nn.Module):
def __init__(self):
super().__init__()
self.n_features=10-1 # first column of input is used for conditioning
self.net1 = nn.Sequential(nn.Linear(9,5),nn.ReLU(),nn.Linear(5,3))
self.net2 = nn.Sequential(nn.Linear(9,5),nn.ReLU(),nn.Linear(5,3))
self.net3 = nn.Sequential(nn.Linear(9,5),nn.ReLU(),nn.Linear(5,3))
def forward(self, X):
# X dim is (-1,10)
out=[]
for x in X:
if x[0].item()<8:
out.append(self.net1(x[1:].reshape(-1,self.n_features)))
elif (8<=x[0].item()<20):
out.append(self.net2(x[1:].reshape(-1,self.n_features)))
elif x[0].item()>=20:
out.append(self.net3(x[1:].reshape(-1,self.n_features)))
return torch.cat(out,dim=0)
model=ExampleMultinet()
x=torch.randn(4,10)*20
y=model(x)