# Conditional computation that saves computation

I want to use conditional computation of the following form. It uses the output of NN1 in stage 1, in order to decide whether to use NN2 or NN3 in the next stage.

```python
def forward(self, x):
    y = self.NN1(x)
    if y > 0:
        return self.NN2(x)
    else:
        return self.NN3(x)
```

I do not know how to implement it. The existing methods compute both NN2(x) and NN3(x) and then multiply one of the two outputs by 0. Such an implementation saves no computation, since both networks are still evaluated.


Hi,

The implementation of the forward method just above is correct, just use that.
You will need to tweak the condition so that it unpacks the content of x out of the Variable and you're good to go.

Hi Alban,
thanks for a quick reply. So during the forward pass, the above code will use either NN2, or NN3 for generating the output? What about the backward pass? Any help will be very useful.

Well, the forward pass creates the computational graph corresponding to what was actually used during the forward, so if you used NN2, it will backward through NN2, and if NN3 was used, it will backward through NN3.
This is actually the main advantage of dynamic graphs (like PyTorch) over static graphs (like TensorFlow): you can do whatever you want in your forward pass, and it will backward through whatever was used.
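
A minimal sketch of that behaviour, assuming small `nn.Linear` layers as stand-ins for NN1, NN2 and NN3 (the class name `Branchy` and the layer sizes are made up for illustration) and a single sample, so that `y` is a scalar. After `backward()`, only the branch that actually ran should have gradients:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class Branchy(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.NN1 = nn.Linear(4, 1)  # stand-in gate network
        self.NN2 = nn.Linear(4, 3)  # stand-in branch
        self.NN3 = nn.Linear(4, 3)  # stand-in branch

    def forward(self, x):
        y = self.NN1(x)
        # .item() pulls the Python number out of the one-element tensor
        if y.item() > 0:
            return self.NN2(x)
        return self.NN3(x)

model = Branchy()
x = torch.randn(1, 4)  # one sample, so the gate output is a scalar
used_nn2 = model.NN1(x).item() > 0
model(x).sum().backward()

# Only the branch that ran is part of the graph and receives gradients
nn2_has_grad = model.NN2.weight.grad is not None
nn3_has_grad = model.NN3.weight.grad is not None
print(used_nn2, nn2_has_grad, nn3_has_grad)
```

Note that NN1 itself gets no gradient in this sketch, because the hard `if` on a Python number is not differentiable.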


The above code will run one of NN2 or NN3 for all the samples in the batch.
If you need to run NN2 for some samples, and NN3 for others, then masking some samples by multiplying by 0 is probably the most efficient method. Doing NN2 and NN3 for all samples and then masking the unwanted results is probably faster than the alternative because the matrix operations are heavily parallelised. Separating out the samples that need to run through NN2 from those that need to run through NN3 and then splicing the results back together properly is almost certainly harder to write and slower to run.
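
For reference, the masking approach can be sketched as follows, assuming `nn.Linear` stand-ins for NN2 and NN3 and a random tensor in place of NN1's output. Both networks run on the whole batch and one result per row is kept, so no compute is saved:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
NN2 = nn.Linear(5, 3)  # stand-in for the real NN2
NN3 = nn.Linear(5, 3)  # stand-in for the real NN3

x = torch.randn(8, 5)
y = torch.randn(8, 1)  # stand-in for NN1(x), shape (batch_size, 1)

out2 = NN2(x)  # computed for every sample
out3 = NN3(x)  # computed for every sample
mask = y > 0   # shape (8, 1), broadcasts over the feature dimension

# Select per row: NN2's output where the gate is positive, NN3's otherwise
out = torch.where(mask, out2, out3)

# Equivalent "multiply by 0" formulation
m = mask.float()
out_mul = m * out2 + (1 - m) * out3
```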


Hi Alban,
I think there is a miscommunication. The quantity y = NN1(x) is a scalar for a single data point x. So I do not understand what you meant when you said to "tweak the condition so that it unpacks the content of x out of the Variable". Can you please elaborate a bit?
Let's say y = NN1(x) returns either 0 or 1 for a single data point x. If we input a batch of 1000 data points, NN1(x) is a 1000-dimensional tensor consisting of 0s and 1s.
Now, if the i-th element of NN1(x) is 1, I want to output NN2(x) for the i-th sample; otherwise I want to output NN3(x) for it.

I have slightly edited the condition so as to avoid miscommunication. Any help will be very useful. Thanks.

Maybe I don't see the difference, but the approaches posted by @SimonW and @jpeg729 in this thread seem to be applicable here.

In order to shortcut plenty of misunderstandings, could you add a single line to the code and show us the output...

```python
def forward(self, x):
    y = self.NN1(x)
    print(x.size(), y.size())  # <- this line
    if y > 0:
        return self.NN2(x)
    else:
        return self.NN3(x)
```

Hi, thanks for the reply. The techniques posted in that thread do not work. At present, I don't know of any technique in PyTorch that does conditional computation. I have read many works, such as those on ACT and Skipnets, but everyone simply multiplies the outputs by 0! Any help will be very useful.

```python
def forward(self, x):
    y = self.NN1(x)
    print(x.size(), y.size())  # <- this line
    # x.size() = (batch_size, number_features)
    # y.size() = (batch_size, 1)

    if y > 0:
        return self.NN2(x)
    else:
        return self.NN3(x)
```

So the decision of whether to use NN2 or NN3 must be taken separately for each sample in the batch.

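```python
def forward(self, x):
    y = self.NN1(x)
    results = []
    # Decide per sample which network to run
    for i in range(len(y)):
        results.append(self.NN2(x[i]) if y[i] > 0 else self.NN3(x[i]))
    # Reassemble the per-sample outputs into one batch
    return torch.stack(results)
```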

I doubt it will be faster, since NN2 and NN3 are probably composed of highly parallelisable matrix operations that run almost as fast on an entire batch as on a single sample. Besides, Python loops are slooooooww.

Even assuming that the solution you mention does work, the problem is that many networks contain batch normalization layers, which normalize data using batch statistics. Here, we would simply be passing batches of size 1.

Batch norm keeps running stats so this shouldn't break it completely.
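
A small sketch of what that means in practice, assuming a bare `nn.BatchNorm1d(5)` layer as a stand-in for the batch norm layers inside NN2 or NN3. As far as I know, training mode refuses a single sample, while eval mode normalises with the running statistics and so treats every row independently of its batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(5)  # stand-in for a batch norm layer inside NN2/NN3

# Training mode: a single sample has no batch statistics to normalise with
bn.train()
try:
    bn(torch.randn(1, 5))
    train_single_ok = True
except ValueError:
    train_single_ok = False

# A few normal-sized batches update the running mean and variance
for _ in range(10):
    bn(torch.randn(32, 5))

# Eval mode: running statistics are used, so batch size 1 is fine
bn.eval()
x = torch.randn(4, 5)
with torch.no_grad():
    full = bn(x)         # whole batch
    single = bn(x[2:3])  # one sample on its own
print(train_single_ok, torch.allclose(full[2:3], single))
```

So single-sample passes are mainly a concern during training, where routing by index in larger groups keeps the batches bigger than one sample.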

Did you test it? Does it run faster?

Hi,

I think a simple way to implement this is using `index_select` and `index_add_` (you could use `index_copy_` as well).
A proof of concept is below. This is not the nicest code, but it should contain all the information you need to do what you want:

```python
import torch

batch_size = 10
feat_size = 5
x = torch.rand(batch_size, feat_size)
y = torch.randn(batch_size, 1)
print('x', x)

# Row indices of the samples routed to each network
nn2_ind = (y >= 0).squeeze(1).nonzero().squeeze(1)
nn3_ind = (y < 0).squeeze(1).nonzero().squeeze(1)
print('Indices for NN2', nn2_ind)
print('Indices for NN3', nn3_ind)

for_nn_2 = x.index_select(0, nn2_ind)
for_nn_3 = x.index_select(0, nn3_ind)
print('Input for NN2', for_nn_2)
print('Input for NN3', for_nn_3)

nn_2_out = for_nn_2 + 2  # replace with NN2
nn_3_out = for_nn_3 + 10  # replace with NN3

# Here other sizes may differ when using actual networks
out = torch.zeros(batch_size, feat_size)
# Scatter each network's output back to its original rows
out.index_add_(0, nn2_ind, nn_2_out)
out.index_add_(0, nn3_ind, nn_3_out)

print('Output', out)
print('The values that used nn2 should be 2.** and the ones using nn3 should be 10.**')
```
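
Wrapped into a module, the same routing idea might look like the sketch below. The class name `Routed`, the `nn.Linear` stand-ins and the `out_features` argument are assumptions for illustration. It uses boolean-mask indexing rather than `index_select` and `index_add_`, which should give the same result here:

```python
import torch
import torch.nn as nn

class Routed(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, in_features=5, out_features=3):
        super().__init__()
        self.out_features = out_features
        self.NN1 = nn.Linear(in_features, 1)             # stand-in gate
        self.NN2 = nn.Linear(in_features, out_features)  # stand-in branch
        self.NN3 = nn.Linear(in_features, out_features)  # stand-in branch

    def forward(self, x):
        # One boolean per sample: True means "send this row to NN2"
        to_nn2 = self.NN1(x).squeeze(1) > 0
        out = x.new_zeros(x.size(0), self.out_features)
        # Each network only sees the rows routed to it
        if to_nn2.any():
            out[to_nn2] = self.NN2(x[to_nn2])
        if (~to_nn2).any():
            out[~to_nn2] = self.NN3(x[~to_nn2])
        return out

torch.manual_seed(0)
model = Routed()
x = torch.randn(10, 5)
out = model(x)
print(out.shape)
```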

Hi,
I didn't try it yet. I will try it soon. Thanks for the reply!

Hi Alban, thanks for the reply! I will check this method and get back in case there are issues with it.

Hi Alban,
your technique works. I am wondering if the network also backpropagates through it?

If PyTorch hasn't complained about inplace operations, missing computation graphs or non-differentiable operations, then it is fairly safe to say that backpropagation works as expected.
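
One quick way to check this, sketched under the assumption that NN2 and NN3 are `nn.Linear` stand-ins and the gate is a fixed tensor rather than a real NN1: run the routed forward, call `backward()`, and compare the input gradients with what each network should produce.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
NN2 = nn.Linear(5, 5)  # stand-in for the real NN2
NN3 = nn.Linear(5, 5)  # stand-in for the real NN3

x = torch.rand(10, 5, requires_grad=True)
# Fixed gate: even rows go to NN2, odd rows go to NN3
y = torch.tensor([1., -1.] * 5).unsqueeze(1)

nn2_ind = (y >= 0).squeeze(1).nonzero().squeeze(1)
nn3_ind = (y < 0).squeeze(1).nonzero().squeeze(1)

out = torch.zeros(10, 5)
out.index_add_(0, nn2_ind, NN2(x.index_select(0, nn2_ind)))
out.index_add_(0, nn3_ind, NN3(x.index_select(0, nn3_ind)))
out.sum().backward()

# For a linear layer, d(sum(W x + b))/dx is the column sum of W, so each
# input row should carry the gradient of the network it was routed to.
expected_nn2_row = NN2.weight.sum(dim=0).detach()
expected_nn3_row = NN3.weight.sum(dim=0).detach()
print(torch.allclose(x.grad[0], expected_nn2_row),
      torch.allclose(x.grad[1], expected_nn3_row))
```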

Have you timed the various solutions?
I'd be interested to know which is fastest.
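
A rough timing harness for the three variants discussed in this thread could look like the sketch below. The network sizes, batch size and function names are all made up for illustration, and the numbers will depend heavily on the real NN2 and NN3 and on the device. On a GPU you would also need `torch.cuda.synchronize()` around the timed region.

```python
import time
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins; replace with the real NN2 and NN3
NN2 = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
NN3 = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
x = torch.randn(256, 64)
gate = torch.randn(256) > 0  # stand-in for the per-sample decision of NN1

def masked(x, gate):
    # Run both networks on everything, keep one result per row
    return torch.where(gate.unsqueeze(1), NN2(x), NN3(x))

def routed(x, gate):
    # Run each network only on the rows routed to it
    out = x.new_zeros(x.size(0), 64)
    out[gate] = NN2(x[gate])
    out[~gate] = NN3(x[~gate])
    return out

def per_sample(x, gate):
    # Python loop over the batch, one sample at a time
    return torch.stack([NN2(x[i]) if gate[i] else NN3(x[i]) for i in range(len(x))])

def time_it(fn, repeats=5):
    with torch.no_grad():
        fn(x, gate)  # warm-up call, not timed
        start = time.perf_counter()
        for _ in range(repeats):
            fn(x, gate)
    return (time.perf_counter() - start) / repeats

timings = {fn.__name__: time_it(fn) for fn in (masked, routed, per_sample)}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.3f} ms")
```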

Hi, presently I haven't timed them. I will do it soon though. Thanks
