I want to implement a conditional computation graph in PyTorch. Since the exact computational graph differs from one data sample to another, I was wondering whether batch processing is possible. My objective is to achieve the following:

Here you are comparing a tensor against 0, which gives you a tensor (a Python object), not a Python bool. This object, when evaluated as a boolean expression, is True-valued, so in fact the code never goes to the second branch.
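A quick way to see what the comparison produces for a batch. This is a sketch assuming a recent PyTorch release, where converting a tensor with more than one element to a Python bool (which is what an `if` does) raises a RuntimeError rather than returning True; either way, a single Python `if` cannot route individual samples. The function names and the example scores are hypothetical.

```python
import torch


def per_sample_condition(scores: torch.Tensor) -> torch.Tensor:
    # Element-wise comparison: one boolean per sample, not a single Python bool.
    return scores.squeeze(1) > 0


def condition_as_python_bool(scores: torch.Tensor):
    # Mimics what `if scores > 0:` does: convert the comparison to a Python bool.
    # Returns the bool, or the RuntimeError message if the conversion is refused.
    try:
        return bool(scores > 0)
    except RuntimeError as err:
        return str(err)


scores = torch.tensor([[0.5], [-1.2], [2.0]])  # e.g. NN1(x) for a batch of 3 samples
print(per_sample_condition(scores))            # one True/False per sample
print(condition_as_python_bool(scores))        # an error message, not a routing decision
```

With a single-element tensor the conversion succeeds, which is why code like this can appear to work when tested with a batch size of 1.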

Hi Simon, can you please tell me the correct way to implement it? At the moment I am completely confused about what to do. I have resorted to evaluating the graph one data point at a time rather than in batches. Any help will be very useful.

Explanation: self.x > 0 returns a tensor of boolean values, which I cast to float using .float(). This gives me a tensor of zeros and ones, where the ones coincide with the positive values in self.x.
Then all I have to do is multiply x by this mask to get a tensor containing the values of x wherever the corresponding element of self.x is positive, and zeros elsewhere.
Doing the same for the negative values in self.x gives me a tensor containing the values of x in most of the other positions. I say most because if any element of self.x == 0, then both tensors will contain 0 in the corresponding position.

Finally, I apply both linear layers and add up the results.
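The masking steps described above can be sketched as a small module. This is a reconstruction under assumptions, not the original code: the class name MaskedTwoLinear is hypothetical, and self.x is assumed to be a learnable vector with one entry per input feature, so its mask broadcasts across the batch dimension. The names linear1 and linear2 follow the thread.

```python
import torch
import torch.nn as nn


class MaskedTwoLinear(nn.Module):
    """Hypothetical module illustrating the mask-and-add idea.

    Assumes self.x holds one learnable value per input feature.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.x = nn.Parameter(torch.randn(in_features))
        self.linear1 = nn.Linear(in_features, out_features)
        self.linear2 = nn.Linear(in_features, out_features)

    def split_input(self, x: torch.Tensor):
        # 1.0 where self.x > 0, else 0.0 (and likewise for self.x < 0).
        pos_mask = (self.x > 0).float()
        neg_mask = (self.x < 0).float()
        # Features where self.x == 0 end up as 0 in both tensors.
        return x * pos_mask, x * neg_mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_pos, x_neg = self.split_input(x)
        # Both linear layers are always evaluated; their outputs are summed.
        return self.linear1(x_pos) + self.linear2(x_neg)


model = MaskedTwoLinear(in_features=4, out_features=2)
print(model(torch.randn(8, 4)).shape)  # batch of 8 samples, 2 outputs each
```

Note that the mask here is per feature, shared by every sample in the batch, so it does not by itself give a different branch per sample.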

Maybe I have misunderstood your requirement…

Note that self.linear1(x_where_self_x_is_positive) will not necessarily be 0 at the positions where self.x == 0, since every output of a linear layer depends on all of its inputs plus a bias term.
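A small numeric check of the note above, assuming a 3-feature input in which only feature 0 survived the mask. The layer size and input values are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear1 = nn.Linear(3, 3)

# Only feature 0 survives the mask; features 1 and 2 were zeroed out.
x_where_self_x_is_positive = torch.tensor([[0.7, 0.0, 0.0]])
out = linear1(x_where_self_x_is_positive)

# Every output j is W[j, 0] * 0.7 + b[j]: the zeroed inputs contribute nothing,
# but the surviving input and the bias still reach all output positions.
expected = linear1.weight[:, 0] * 0.7 + linear1.bias
print(out)
print(expected)
```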

Thanks a lot for the reply, but my problem is that the next layers (linear1 and linear2 here) might not be linear at all. Similarly, the condition I apply might not be linear either. The general case is as follows:

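```python
def forward(self, x):
    # Intended behaviour: route each sample by the sign of NN1(x).
    # With a batch, self.NN1(x) > 0 is a tensor of booleans, so this single
    # `if` cannot send different samples to different branches.
    if self.NN1(x) > 0:
        return self.NN2(x)
    else:
        return self.NN3(x)
```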

Basically, for samples x with NN1(x) > 0 I want to output NN2(NN1(x)), and for samples with NN1(x) < 0 I want NN3(NN1(x)). I am wondering whether your solution covers this case. Thanks for the reply!
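One way to handle arbitrary (non-linear) branches in a batch is to evaluate both branches for every sample and keep one result per row with torch.where. This is a sketch under assumptions: NN1 is taken to return one score per sample with shape (batch_size, 1), NN2 and NN3 take x as in the forward above (feeding them NN1(x) instead only changes the argument), samples with NN1(x) == 0 go to NN3, and the layer architectures are hypothetical placeholders.

```python
import torch
import torch.nn as nn


class SelectBothBranches(nn.Module):
    """Hypothetical sketch: evaluate both branches, keep one result per sample."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.NN1 = nn.Sequential(nn.Linear(in_features, 16), nn.Tanh(), nn.Linear(16, 1))
        self.NN2 = nn.Sequential(nn.Linear(in_features, 16), nn.ReLU(), nn.Linear(16, out_features))
        self.NN3 = nn.Sequential(nn.Linear(in_features, 16), nn.Sigmoid(), nn.Linear(16, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        cond = self.NN1(x) > 0   # shape (batch_size, 1), broadcasts over output features
        out_pos = self.NN2(x)    # computed for every sample
        out_neg = self.NN3(x)    # computed for every sample
        # Row i comes from NN2 where cond[i] is True, otherwise from NN3.
        return torch.where(cond, out_pos, out_neg)


model = SelectBothBranches(in_features=4, out_features=3)
print(model(torch.randn(10, 4)).shape)
```

The cost is that both branches run on the whole batch; gradients only flow into the branch that was selected for each row.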

What I'm not understanding is the criterion you have for branching. Is it that if the first element in data is 0 you go to one of the branches? In any case, one way to separate a batch is to use the condition to get a mask (a 1D tensor whose length equals the batch size). Then use that mask and its inverse (i.e. its logical negation) to index into your tensor to get two batches, one for each branch.
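Building such a mask and splitting a batch with it could look like this. The random decision scores are a stand-in for whatever condition is used (for example a decision network's output with shape (batch_size, 1)); the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
batch_size, num_features = 6, 4
x = torch.randn(batch_size, num_features)
decision = torch.randn(batch_size, 1)   # stand-in for the per-sample condition scores

mask = decision.squeeze(1) > 0          # bool tensor, shape (batch_size,)
x_pos = x[mask]                         # rows whose score is positive
x_neg = x[~mask]                        # all remaining rows
print(mask.shape, x_pos.shape, x_neg.shape)
```

Every row of x lands in exactly one of x_pos and x_neg, and each keeps all of its features.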

Sorry, I forgot to mention this part in my previous reply.

Hi Simon, thanks for getting back to me. The condition is as follows. Suppose I have a neural network called DNN (decision neural network). If, for a data point x, the output DNN(x) is positive, I use NN1 (neural network 1); if DNN(x) is negative, I use NN2.

Hi Simon,
Is there a way I can avoid computing both the NN2 and NN3 outputs and then merging them? For a single input x, only one of NN2 or NN3 needs to be computed, so skipping the other would save some computation. Thanks in advance.
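A sketch that runs each branch only on its own samples and writes the results back into the original row order. It uses the DNN / NN1 / NN2 naming from the decision rule described earlier in the thread; the class name RoutedNet, the assumption that dnn returns shape (batch_size, 1), and sending DNN(x) == 0 to nn2 are all hypothetical choices.

```python
import torch
import torch.nn as nn


class RoutedNet(nn.Module):
    """Hypothetical sketch: each sample goes through exactly one branch network."""

    def __init__(self, dnn: nn.Module, nn1: nn.Module, nn2: nn.Module, out_features: int):
        super().__init__()
        self.dnn, self.nn1, self.nn2 = dnn, nn1, nn2
        self.out_features = out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes dnn returns shape (batch_size, 1); DNN(x) == 0 is sent to nn2.
        mask = self.dnn(x).squeeze(1) > 0
        out = x.new_zeros(x.shape[0], self.out_features)
        if mask.any():
            out[mask] = self.nn1(x[mask])      # nn1 only sees its own samples
        if (~mask).any():
            out[~mask] = self.nn2(x[~mask])    # nn2 only sees the rest
        return out


model = RoutedNet(nn.Linear(4, 1), nn.Linear(4, 3),
                  nn.Sequential(nn.Linear(4, 3), nn.Tanh()), out_features=3)
print(model(torch.randn(12, 4)).shape)
```

The `> 0` comparison is not differentiable, so in this sketch the decision network receives no gradient through the routing; training it would need a separate signal or a soft (e.g. weighted) decision instead.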

Hi Simon,
I am not able to understand how to generate the mask and use it to index the tensor x.
The tensor x has dimensions (batch_size, number_of_features). You mentioned that the following will work:

I am wondering what values the variable branch takes. If branch is binary (0 or 1), what would x[0] and x[1] mean, given that x has dimensions (batch_size, number_of_features)?
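The answer depends on the dtype of branch. A short illustration, assuming branch is a per-sample flag: with a boolean tensor, x[branch] selects the rows where the flag is True, whereas an integer 0/1 tensor is treated as a list of row indices. The toy tensor x is hypothetical.

```python
import torch

x = torch.arange(12.0).reshape(4, 3)               # batch_size=4, num_features=3
branch = torch.tensor([True, False, True, False])  # one flag per sample

first_sample = x[0]         # shape (3,): the first row, i.e. one data point
selected = x[branch]        # shape (2, 3): rows 0 and 2, where branch is True
int_branch = branch.long()  # tensor([1, 0, 1, 0])
gathered = x[int_branch]    # shape (4, 3): rows 1, 0, 1, 0; integer indices, not a mask
print(selected)
print(gathered)
```

So a mask for splitting a batch should be a bool tensor of length batch_size; x[0] and x[1] on their own are simply the first two samples.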

I am also confused about what torch.inverse means here. I looked it up, and it seems that torch.inverse returns the matrix inverse.
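For a mask, "inverse" means the element-wise complement, not torch.inverse (which is indeed a matrix inverse). A few equivalent ways to get it:

```python
import torch

mask = torch.tensor([True, False, True])
inv_a = ~mask                      # element-wise NOT on a bool tensor
inv_b = torch.logical_not(mask)    # same result, spelled out
float_mask = mask.float()
inv_c = 1.0 - float_mask           # for a 0/1 float mask, the complement is 1 - mask
print(inv_a, inv_b, inv_c)
```

For indexing rows, the bool forms (inv_a or inv_b) are the ones to use; the float form suits the multiply-by-mask approach.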
I think there is a miscommunication. I want each feature vector to go through only a single NN. Let's say the dataset has N samples x1, x2, …, xN. Then each data point xi goes through either NN1 or NN2.
Any help will be very useful. Thanks.
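One way to check that a batched version sends each xi through exactly one network, and gives the same outputs as evaluating data points one at a time, is sketched below. This variant concatenates the two branch outputs and restores the original order with argsort instead of writing into a preallocated tensor. RowCounter, route_batch and route_one_by_one are hypothetical helpers, and dnn is assumed to output shape (N, 1).

```python
import torch
import torch.nn as nn


class RowCounter(nn.Module):
    """Hypothetical wrapper that records how many rows (samples) a network has seen."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net
        self.rows_seen = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.rows_seen += x.shape[0]
        return self.net(x)


def route_batch(x, dnn, nn1, nn2):
    # One decision per sample (assumes dnn outputs shape (N, 1)).
    mask = dnn(x).squeeze(1) > 0
    idx_pos = mask.nonzero(as_tuple=True)[0]
    idx_neg = (~mask).nonzero(as_tuple=True)[0]
    # Each row goes to exactly one network; results come back grouped by branch.
    out = torch.cat([nn1(x[idx_pos]), nn2(x[idx_neg])])
    order = torch.cat([idx_pos, idx_neg])
    # argsort of a permutation is its inverse: row i of the result belongs to x_i.
    return out[torch.argsort(order)]


def route_one_by_one(x, dnn, nn1, nn2):
    # Reference: evaluate each data point separately, as in a per-sample loop.
    rows = []
    for xi in x:
        xi = xi.unsqueeze(0)
        rows.append(nn1(xi) if dnn(xi).item() > 0 else nn2(xi))
    return torch.cat(rows)
```

After one call to route_batch on N samples, nn1.rows_seen + nn2.rows_seen equals N, which confirms that no sample was processed by both networks.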