PyTorch Conditional computation not working

I want to implement a conditional computation graph in PyTorch. Since the exact computational graph differs from one data sample to another, I was wondering whether batch processing is possible. My objective is to achieve the following:

def forward(self, x):
    if self.x > 0:
        return self.linear1(x)
    return self.linear2(x)

But the above doesn’t work. Am I missing something?

What you want to achieve is entirely possible.

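Here you are comparing a tensor against 0, which gives you another tensor (a Python object), not a Python bool. This object, when evaluated as a bool expression, is True-valued, so in fact the code never goes to the second branch. (Current PyTorch releases instead raise a RuntimeError when a tensor with more than one element is used as an if condition.)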
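A minimal sketch of the underlying issue, as it behaves in a recent PyTorch release: the comparison is element-wise and produces a boolean tensor, while `if` needs one Python bool. The tensor values are made up for illustration.

```python
import torch

x = torch.tensor([[1.0, -2.0], [-3.0, 4.0]])
cond = x > 0  # element-wise comparison: a bool tensor, not a Python bool
print(cond.dtype, cond.shape)  # torch.bool torch.Size([2, 2])

try:
    if cond:  # ambiguous: which of the four elements should decide?
        pass
except RuntimeError as err:
    print("RuntimeError:", err)

# To branch on the whole tensor, reduce it to a single value explicitly.
print(bool(cond.any()), bool(cond.all()))  # True False
```

Reducing with `any()`/`all()` only makes sense when one decision applies to the whole tensor; a per-sample decision needs masking instead.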

Hi Simon, can you please tell me the correct way to implement it? At the moment I am totally confused about what to do. I have resorted to evaluating graphs on a per-data-point basis rather than a per-batch basis. Any help will be very useful.

How about sidestepping the problem like this?

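def forward(self, x):
    # Zero out the entries of x where self.x is not positive / not negative
    x_where_self_x_is_positive = x * (self.x > 0).float()
    x_where_self_x_is_negative = x * (self.x < 0).float()
    # Each layer sees only its masked copy of x; the two results are summed
    return self.linear1(x_where_self_x_is_positive) + self.linear2(x_where_self_x_is_negative)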

self.x > 0 returns a tensor of boolean values, which I cast to float using .float(). This gives me a tensor of zeros and ones, where the ones coincide with the positive values in self.x.
Then all I have to do is multiply x by this mask and I get a tensor containing values of x where the corresponding element of self.x is positive, and zeros elsewhere.
Doing the same for the negative values in self.x gives me a tensor containing values of x in most of the other positions. I say most because if any element of self.x == 0, then both tensors will contain 0 in the corresponding position.

Finally, I calculate both linear layers, and add up the results.

Maybe I have misunderstood your requirement…

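Note that self.linear1(x_where_self_x_is_positive) will not necessarily be 0 at the positions where self.x == 0: a linear layer adds its bias, and each output element mixes all input features, so zeroing an input does not zero the corresponding output.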
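A minimal sketch of the input-masking approach as a runnable module. Assumptions: the condition is passed in as a hypothetical `cond` argument standing in for `self.x`, with the same shape as `x`; the class name `MaskedInputs` and the layer sizes are made up. The check at the end illustrates the bias effect: with every condition positive, `linear2` receives an all-zero input but still adds its bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class MaskedInputs(nn.Module):
    """Hypothetical module: `cond` plays the role of self.x (same shape as x)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear1 = nn.Linear(in_features, out_features)
        self.linear2 = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # 1.0 where cond is positive / negative, 0.0 elsewhere (zeros fall in neither mask)
        x_pos = x * (cond > 0).float()
        x_neg = x * (cond < 0).float()
        # Both layers run on the full batch and their outputs are summed
        return self.linear1(x_pos) + self.linear2(x_neg)


model = MaskedInputs(3, 2)
x = torch.randn(4, 3)
cond = torch.ones(4, 3)  # every element positive, so x_neg is all zeros
out = model(x, cond)
# linear2(zeros) equals linear2.bias, so it still contributes to the output
print(torch.allclose(out, model.linear1(x) + model.linear2.bias))  # True
```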

Maybe this fits your requirements better…

def forward(self, x):
    # Run both layers on all of x, then keep each result only where its condition holds
    result_where_self_x_is_positive = self.linear1(x) * (self.x > 0).float()
    result_where_self_x_is_negative = self.linear2(x) * (self.x < 0).float()
    return result_where_self_x_is_positive + result_where_self_x_is_negative
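A minimal sketch of the output-masking idea for the case where the condition is one value per sample. Assumptions: `cond` is a hypothetical tensor of shape `(batch_size,)`, and `torch.where` is swapped in for the multiply-and-add; it picks the same rows for positive and negative conditions, but a sample with `cond == 0` gets `linear2`'s output here instead of 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

linear1 = nn.Linear(3, 2)
linear2 = nn.Linear(3, 2)

x = torch.randn(5, 3)
cond = torch.tensor([1.0, -1.0, 2.0, -0.5, 3.0])  # hypothetical per-sample condition

# Both layers still run on the whole batch; only the selection is per sample
out1 = linear1(x)
out2 = linear2(x)
take_first = (cond > 0).unsqueeze(1)  # shape (5, 1), broadcasts over output features
out = torch.where(take_first, out1, out2)
```

This keeps everything batched, at the cost of computing both branches for every sample.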

Thanks a lot for the reply, but my problem is that the next layers (linear1 and linear2) might not be linear at all. Similarly, the condition I apply might not be linear either. The general case is as follows:

def forward(self, x):
    if self.NN1(x) > 0:
        return self.NN2(x)
    return self.NN3(x)

Basically, for data samples x with NN1(x) > 0, I want to output NN2(NN1(x)), and for data points x with NN1(x) < 0, I want to output NN3(NN1(x)). I am wondering whether your solution covers this case. Thanks for the reply!

I might have completely misunderstood what you want. Does NN1 output a single scalar value for each sample in the batch?

What I’m not understanding is the criterion you have for branching. Is it that if the first element in the data is 0, you go to one of the branches? In any case, one way to split a batch is to use the condition to get a mask (1-D, with the same size as the batch). Then use that mask and its inverse (logical NOT) to index into your tensor to get two batches, one for each branch.
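A minimal sketch of the mask-splitting step, assuming one decision value per sample; `scores` is a hypothetical stand-in for whatever produces the condition.

```python
import torch

x = torch.arange(12.0).reshape(4, 3)  # batch_size=4, 3 features
scores = torch.tensor([0.7, -1.2, 0.1, -0.3])  # hypothetical per-sample decision values

mask = scores > 0  # 1-D bool mask, one entry per sample
batch_a = x[mask]  # rows where the condition holds
batch_b = x[~mask]  # the remaining rows (~ is logical NOT for bool tensors)
print(batch_a.shape, batch_b.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```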

Sorry, I forgot to mention this part in my previous reply.

Hi Simon, thanks for replying. The condition can be stated as follows. Suppose that I have a neural network called DNN (short for decision neural network). If, for a data point x, the output of DNN, i.e. the quantity DNN(x), is positive, then I use NN1 (short for neural network 1); if DNN(x) is negative, I use NN2.

Yes, you are correct. NN1 outputs a single scalar value for each data point x.

Then I think my last code sample is correct.

I see. Then my suggestion still applies. The high-level logic would be:

branch = DNN(x)
y1 = NN1(x[branch])
y2 = NN2(x[torch.inverse(branch)])
# then merge y1 and y2 together somehow
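A runnable version of the high-level logic above, written as a minimal sketch under these assumptions: DNN returns shape `(batch_size, 1)`, both branch networks produce outputs of the same width, and the results are merged by writing them back into a preallocated tensor in the original sample order (one choice among several). The module name `RoutedModel` and all layer sizes are hypothetical. Note that `torch.inverse` computes a matrix inverse; the complement of a boolean mask is `~branch` (equivalently `torch.logical_not(branch)`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class RoutedModel(nn.Module):
    """Hypothetical module: dnn picks the branch, nn1/nn2 are the branch networks."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.dnn = nn.Linear(in_features, 1)  # one decision score per sample
        self.nn1 = nn.Sequential(nn.Linear(in_features, 8), nn.ReLU(), nn.Linear(8, out_features))
        self.nn2 = nn.Sequential(nn.Linear(in_features, 8), nn.Tanh(), nn.Linear(8, out_features))
        self.out_features = out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.dnn(x).squeeze(1) > 0  # bool mask of shape (batch_size,)
        y = x.new_empty(x.shape[0], self.out_features)
        # Each sample goes through exactly one branch network
        y[branch] = self.nn1(x[branch])
        y[~branch] = self.nn2(x[~branch])  # ~ is logical NOT, not torch.inverse
        return y


model = RoutedModel(in_features=4, out_features=3)
x = torch.randn(6, 4)
y = model(x)
print(y.shape)  # torch.Size([6, 3])
```

Each row of `x` is passed through only one of `nn1`/`nn2`, and the rows of `y` stay in the input order. Because the `> 0` comparison is not differentiable, `dnn` receives no gradient from `y` in this sketch.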

Hi Simon,
is there a way in which I can avoid computing both NN2 and NN3 outputs and merging them? Because for a single input x, either NN2 needs to be computed, or NN3 needs to be computed. That way I will be able to save some computations. Thanks in advance.

Sorry, I’m not sure I understand. With the approach I suggested, each data point only goes through one net.

Hi Simon,
I am not able to understand how to generate the mask and use it for indexing the tensor x.
The tensor x has dimensions (batch_size, number_of_features). Now, you mentioned that the following will work:

branch = DNN(x)
y1 = NN1(x[branch])
y2 = NN2(x[torch.inverse(branch)])
# merge y1 and y2 somehow

I am wondering what values the variable branch will take. If branch is binary (0 or 1), then what would x[0] and x[1] mean (note that x has dimensions (batch_size, number_of_features))?

I am also confused about what torch.inverse means. I searched for torch.inverse online, and it seems that torch.inverse returns the matrix inverse.
I think there is a miscommunication. I want each feature vector to go through only a single NN. Let's say that the dataset has N samples, x1, x2, …, xN. Then each data point xi goes through either NN1 or NN2.
Any help will be very useful. Thanks.
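On the question of what `x[branch]` means: with `branch = DNN(x) > 0` (assuming one score per sample), `branch` is a boolean tensor of shape `(batch_size,)`, and indexing with it keeps the matching rows, so `x[branch]` has shape `(k, number_of_features)`. The dtype matters, though: a tensor of integer 0s and 1s is treated as a list of row indices, not as a mask. A small sketch with made-up values:

```python
import torch

x = torch.tensor([[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]])  # batch_size=3, 2 features

as_bool = torch.tensor([True, False, True])
as_long = torch.tensor([1, 0, 1])

# A bool tensor is a mask: keep the rows where it is True
rows_masked = x[as_bool]  # rows 0 and 2
# An integer tensor is a list of row indices: row 1, row 0, row 1
rows_indexed = x[as_long]
# Convert 0/1 integers to bool explicitly when masking is intended
rows_converted = x[as_long.bool()]  # rows 0 and 2
```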