# PyTorch Conditional computation not working

I want to implement a conditional computation graph in PyTorch. Since the exact computational graph differs from one data sample to another, I was wondering whether batch processing is possible. My objective is to achieve the following:

```python
def forward(self, x):
    if (self.x.data > 0):
        return self.linear1(x)
    else:
        return self.linear2(x)
```

But the above doesn’t work. Am I missing something?

What you want to achieve is entirely possible.

Here you are comparing a tensor against 0, which gives you another tensor (a Python object). When that tensor is evaluated as a boolean expression, it is treated as True-valued, so execution never reaches the second branch.
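A quick sketch of what happens (not from the original thread; note that recent PyTorch versions raise an error for a multi-element tensor in a boolean context rather than silently evaluating it as True):

```python
import torch

# Comparing a tensor against 0 returns another tensor, not a Python bool.
t = torch.tensor([1.0, -2.0, 3.0])
mask = t > 0
print(type(mask))   # <class 'torch.Tensor'>
print(mask)         # tensor([ True, False,  True])

# Using a multi-element tensor in an `if` is ambiguous; current PyTorch
# raises a RuntimeError instead of picking a branch.
try:
    if mask:
        pass
except RuntimeError as e:
    print("RuntimeError:", e)
```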

Hi Simon, can you please tell me the correct way to implement this? At present I am not sure what to do, and I have resorted to evaluating graphs per data point rather than per batch. Any help will be very useful.

How about sidestepping the problem like this?

```python
def forward(self, x):
    x_where_self_x_is_positive = x * (self.x > 0).float()
    x_where_self_x_is_negative = x * (self.x < 0).float()
    return self.linear1(x_where_self_x_is_positive) + self.linear2(x_where_self_x_is_negative)
```

Explanation:
`self.x > 0` returns a tensor of boolean values, which I cast to float using `.float()`. This gives me a tensor of zeros and ones, where the ones coincide with the positive values in `self.x`.
Then all I have to do is multiply `x` by this mask, and I get a tensor containing the values of `x` wherever the corresponding element of `self.x` is positive, and zeros elsewhere.
Doing the same for the negative values of `self.x` gives me a tensor containing the values of `x` in most of the other positions. I say most because if any element of `self.x` == 0, then both tensors will contain 0 in the corresponding position.

Finally, I calculate both linear layers, and add up the results.
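A minimal sketch of the masking step, with made-up values standing in for `x` and `self.x` (assuming they have matching sizes):

```python
import torch

x = torch.tensor([10.0, 20.0, 30.0, 40.0])
self_x = torch.tensor([1.0, -1.0, 0.0, 2.0])  # plays the role of self.x

pos_mask = (self_x > 0).float()   # tensor([1., 0., 0., 1.])
neg_mask = (self_x < 0).float()   # tensor([0., 1., 0., 0.])

x_pos = x * pos_mask              # tensor([10.,  0.,  0., 40.])
x_neg = x * neg_mask              # tensor([ 0., 20.,  0.,  0.])

# Note the position where self_x == 0 is zero in BOTH masked tensors.
```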

Maybe I have misunderstood your requirement…

Note that `self.linear1(x_where_self_x_is_positive)` will not necessarily be 0 at the positions where `self.x` == 0, since a linear layer maps a zero input to its bias rather than to zero.

Maybe this fits your requirements better…

```python
def forward(self, x):
    result_where_self_x_is_positive = self.linear1(x) * (self.x > 0).float()
    result_where_self_x_is_negative = self.linear2(x) * (self.x < 0).float()
    return result_where_self_x_is_positive + result_where_self_x_is_negative
```
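A small check (not from the original thread) that this output-masking version matches per-sample branching exactly, assuming `self.x` holds one scalar per sample:

```python
import torch

torch.manual_seed(0)
linear1 = torch.nn.Linear(3, 2)
linear2 = torch.nn.Linear(3, 2)

x = torch.randn(4, 3)                                   # batch of 4 samples
self_x = torch.tensor([[1.0], [-1.0], [2.0], [-3.0]])   # one scalar per sample

# Masked version: compute both layers, zero out the "wrong" branch per sample.
out = linear1(x) * (self_x > 0).float() + linear2(x) * (self_x < 0).float()

# Per-sample loop for comparison.
expected = torch.stack([
    (linear1(xi) if si > 0 else linear2(xi)).squeeze(0)
    for xi, si in zip(x.split(1), self_x)
])
print(torch.allclose(out, expected))  # True
```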

Thanks a lot for the reply, but my problem is that the next layers (linear1 and linear2 here) might not be linear at all. Similarly, the condition I apply might also not be linear. The general case is as follows:

```python
def forward(self, x):
    if self.NN1(x) > 0:
        return self.NN2(x)
    else:
        return self.NN3(x)
```

Basically, for data samples x that have NN1(x)>0, I want to output NN2(NN1(x)), and for data points x that have NN1(x)<0, NN3(NN1(x)). I am wondering whether your solution will cover this case. Thanks for the reply!

I might have completely misunderstood what you want. Does NN1 output a single scalar value for each sample in the batch?

What I’m not understanding is the criterion you have for branching. Is it that if the first element in the data is 0 you go to one of the branches? In any case, a way to separate a batch is to use the condition to get a mask (1-D, with the same size as the batch size). Then use that mask and its inverse to index into your tensor to get two batches, one for each branch.
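A minimal sketch of the mask-and-index idea (hypothetical shapes; `scores` stands in for the per-sample branching criterion):

```python
import torch

x = torch.randn(5, 3)                                # (batch_size, num_features)
scores = torch.tensor([0.5, -1.0, 2.0, -0.3, 0.1])   # one scalar per sample

mask = scores > 0          # 1-D boolean mask over the batch, shape (5,)
x_branch1 = x[mask]        # rows where the condition holds
x_branch2 = x[~mask]       # remaining rows (~ is logical NOT on a bool tensor)

print(x_branch1.shape)     # torch.Size([3, 3])
print(x_branch2.shape)     # torch.Size([2, 3])
```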

Sorry that I forgot to mention this part in previous reply.

Hi Simon, thanks for replying back. The condition can be put forth as follows. Suppose that I have a neural network called DNN (short for decision neural network). Now, if for data point x, the output of DNN, i.e. the quantity DNN(x) is positive, then I decide to use NN1 (short for neural network 1), while if DNN(x) is negative, I decide to use NN2.

yes, you are correct. NN1 outputs a single scalar value for each single data point x.

Then I think my last code sample is correct.

I see. Then my suggestion still applies. The high-level logic would be:

```python
branch = DNN(x)
y1 = NN1(x[branch])
y2 = NN2(x[torch.inverse(branch)])
# then merge y1 and y2 together somehow
```

Hi Simon,
is there a way in which I can avoid computing both NN2 and NN3 outputs and merging them? Because for a single input x, either NN2 needs to be computed, or NN3 needs to be computed. That way I will be able to save some computations. Thanks in advance.

Sorry I’m not sure if I understand. With the approach I suggested, each data point only goes through one net.
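A fuller sketch of that approach (a minimal example, not from the original thread; it assumes the decision net outputs one scalar per sample and both branch nets share an output size, here hard-coded as 2):

```python
import torch

torch.manual_seed(0)
NN1 = torch.nn.Linear(4, 1)   # decision net: one scalar per sample (hypothetical)
NN2 = torch.nn.Linear(4, 2)   # branch nets with matching output size
NN3 = torch.nn.Linear(4, 2)

def forward(x):
    mask = NN1(x).squeeze(1) > 0          # 1-D boolean mask over the batch
    out = x.new_empty(x.size(0), 2)       # output buffer
    if mask.any():
        out[mask] = NN2(x[mask])          # each sample goes through ONE net
    if (~mask).any():
        out[~mask] = NN3(x[~mask])
    return out

y = forward(torch.randn(6, 4))
print(y.shape)  # torch.Size([6, 2])
```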

Hi Simon,
I am not able to understand how to generate the mask and use it to index the tensor x.
The tensor x has dimensions (batch_size, number_of_features). Now, you mentioned the following will work:

```python
branch = DNN(x)
y1 = NN1(x[branch])
y2 = NN2(x[torch.inverse(branch)])
# merge y1 and y2 somehow
```

I am wondering what values the variable branch will assume. If branch is binary (0 or 1), then what would x[0] and x[1] mean (note that x has dimensions (batch_size, number_of_features))?

I am also confused about what torch.inverse means. I searched for torch.inverse online, and it seems that it returns the matrix inverse.

I think there is a miscommunication. I want a single feature vector to go through only a single NN. Let's say the dataset has N samples called x1, x2, …, xN. Then each data point xi goes through either NN1 or NN2.

Any help will be very useful. Thanks.
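For what it's worth, `torch.inverse` does compute a matrix inverse, which is not what is needed here; the complement of a boolean mask is the logical NOT, written `~mask` (or `mask.logical_not()`). A minimal sketch:

```python
import torch

mask = torch.tensor([True, False, True, False])

# The complement of a boolean mask is the logical NOT, not a matrix inverse:
print(~mask)    # tensor([False,  True, False,  True])

x = torch.arange(8.0).view(4, 2)   # (batch_size, number_of_features)
print(x[mask])   # rows 0 and 2: tensor([[0., 1.], [4., 5.]])
print(x[~mask])  # rows 1 and 3: tensor([[2., 3.], [6., 7.]])
```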