I want to use conditional computation of the following form. It uses the output of NN1 in stage 1, in order to decide whether to use NN2 or NN3 in the next stage.

def forward(self, x):
    y = self.NN1(x)
    if y > 0:
        return self.NN2(x)
    else:
        return self.NN3(x)

I do not know how to implement it. The existing methods compute both NN2(x) and NN3(x), and then multiply one of them by 0. However, such an implementation saves no computation.

The implementation of the forward method just above is correct, just use that.
You will need to tweak the condition so that it unpacks the content of x out of the Variable and you're good to go.
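A minimal sketch of that suggestion, assuming a recent PyTorch where Variable has been merged into Tensor: `.item()` pulls the Python number out of a one-element tensor, so the `if` sees a plain float. The class name `BatchLevelBranch` and the layer sizes are hypothetical placeholders, and the sketch assumes the gate produces a single value (for example, a batch of size 1).

```python
import torch
import torch.nn as nn

class BatchLevelBranch(nn.Module):
    """Hypothetical module: one gate decision for the whole input.

    NN1/NN2/NN3 are placeholder layers; the gate is assumed to return a
    single value (e.g. batch size 1).
    """

    def __init__(self, in_features=5, out_features=3):
        super().__init__()
        self.NN1 = nn.Linear(in_features, 1)
        self.NN2 = nn.Linear(in_features, out_features)
        self.NN3 = nn.Linear(in_features, out_features)

    def forward(self, x):
        y = self.NN1(x)            # shape (1, 1) for a single sample
        if y.item() > 0:           # .item() unpacks the value as a Python float
            return self.NN2(x)     # only NN2 is evaluated
        else:
            return self.NN3(x)     # only NN3 is evaluated

torch.manual_seed(0)
model = BatchLevelBranch()
out = model(torch.rand(1, 5))
print(out.shape)  # torch.Size([1, 3])
```

Only one of the two branches is ever executed, which is the computational saving the question is after. `.item()` raises an error if the gate has more than one element, so this only covers a single decision per forward call.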

Hi Alban,
thanks for the quick reply. So during the forward pass, the above code will use either NN2 or NN3 to generate the output? What about the backward pass? Any help would be very useful.

Well, the forward pass builds the computational graph from whatever was actually used, so if you used NN2, backpropagation goes through NN2, and if NN3 was used, it goes through NN3.
This is actually the main advantage of a dynamic graph (as in PyTorch) over a static graph (as in TensorFlow): you can do whatever you want in your forward pass, and the backward pass will go through whatever was used.
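One way to check this claim is to look at which parameters receive gradients after `backward()`. In this sketch the three `nn.Linear` layers are placeholders assumed for illustration. Note that the gate NN1 also receives no gradient, because `.item()` turns its output into a plain Python number and a hard decision has no gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Placeholder networks standing in for NN1, NN2, NN3 (assumed, not from the thread).
NN1 = nn.Linear(4, 1)
NN2 = nn.Linear(4, 2)
NN3 = nn.Linear(4, 2)

x = torch.rand(1, 4)
y = NN1(x)
# The gate is evaluated in Python; the graph only records the branch actually taken.
out = NN2(x) if y.item() > 0 else NN3(x)
out.sum().backward()

used, unused = (NN2, NN3) if y.item() > 0 else (NN3, NN2)
print(used.weight.grad is not None)   # True: gradients reached the branch that ran
print(unused.weight.grad is None)     # True: the other branch was never in the graph
print(NN1.weight.grad is None)        # True: the hard decision blocks gradients to the gate
```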

The above code will run one of NN2 or NN3 for all the samples in the batch.
If you need to run NN2 for some samples, and NN3 for others, then masking some samples by multiplying by 0 is probably the most efficient method. Doing NN2 and NN3 for all samples and then masking the unwanted results is probably faster than the alternative because the matrix operations are heavily parallelised. Separating out the samples that need to run through NN2 from those that need to run through NN3 and then splicing the results back together properly is almost certainly harder to write and slower to run.
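For reference, the masking approach can be written with `torch.where` instead of an explicit multiplication by 0, which selects one row from each branch per sample. This is a sketch with placeholder layers and sizes; both branches still run on the whole batch, so no computation is saved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, in_features, out_features = 8, 5, 3
# Placeholder networks (assumed shapes, for illustration only).
NN1 = nn.Linear(in_features, 1)
NN2 = nn.Linear(in_features, out_features)
NN3 = nn.Linear(in_features, out_features)

x = torch.rand(batch_size, in_features)
mask = NN1(x) > 0                          # (batch_size, 1) boolean, one gate per sample

# Both branches run on the full batch; torch.where keeps one row from each.
out = torch.where(mask, NN2(x), NN3(x))    # mask broadcasts over the feature dimension
print(out.shape)  # torch.Size([8, 3])
```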

Hi Alban,
I think there is a miscommunication. The quantity y = NN1(x) is a scalar for a single data point x, so I don't understand what you meant by "tweak the condition so that it unpacks the content of x out of the Variable". Could you elaborate a bit?
Let's say y = NN1(x) returns either 0 or 1 for a single data point x. If we input a batch of 1000 data points, NN1(x) is a 1000-dimensional tensor of 0s and 1s.
Now, if the i-th element of NN1(x) is 1, I want to output NN2(x[i]) for that sample, otherwise NN3(x[i]).
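This is where the single `if` breaks down: a tensor with more than one element has no single truth value, so PyTorch raises a `RuntimeError`. A small sketch with a hypothetical 6-sample gate output, showing the error and the boolean mask that per-sample routing has to work with instead:

```python
import torch

torch.manual_seed(0)
# Hypothetical gate output for a batch of 6 samples: one 0/1 decision each.
y = (torch.randn(6, 1) > 0).long()

try:
    if y > 0:          # a 6-element tensor has no single truth value
        pass
    raised = False
except RuntimeError as err:
    raised = True
    print(err)         # PyTorch's "ambiguous" boolean-value error

# Per-sample decisions have to be kept as a boolean mask instead.
use_nn2 = y.view(-1).bool()   # shape (6,)
print(use_nn2)
```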

Hi, thanks for the reply. The techniques posted in that thread do not work. At present, I don't know of any technique in PyTorch that does conditional computation. I have read many works, such as ACT and SkipNet, but they all simply multiply the unused outputs by 0! Any help would be very useful.

So the decision of whether to use NN2 or NN3 must be taken separately for each sample in the batch.

How about this…

def forward(self, x):
    y = self.NN1(x)
    results = []
    for i in range(len(y)):
        # x[i:i + 1] keeps the batch dimension, so each result has shape (1, ...)
        results.append(self.NN2(x[i:i + 1]) if y[i].item() else self.NN3(x[i:i + 1]))
    return torch.cat(results, dim=0)

I doubt it will be faster, since NN2 and NN3 are probably composed of highly parallelisable matrix operations that run almost as fast on an entire batch as on a single sample. Besides, Python loops are slow.

Even assuming that this solution works, the problem is that many of the layers use batch normalization, which normalizes using statistics computed over the batch. Here we would be passing batches of size 1.
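The batch-size-1 concern is real: in training mode `nn.BatchNorm1d` cannot compute a variance from a single value per channel and raises a `ValueError`. The sketch below (sizes are arbitrary assumptions) shows that, and that a routed subgroup of several samples goes through fine. Note that such a subgroup still gets its own statistics rather than the full batch's, and a subgroup that happens to contain one sample would hit the same error.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.train()                       # training mode uses per-batch statistics

single = torch.rand(1, 4)        # what a per-sample loop would feed in
try:
    bn(single)
    failed = False
except ValueError:
    failed = True                # one value per channel: batch statistics are undefined

# Routing a whole subgroup at once keeps the batch size above 1,
# although the statistics are then computed per subgroup, not over the full batch.
group = torch.rand(3, 4)
print(bn(group).shape)           # torch.Size([3, 4])
print(failed)                    # True
```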

I think a simple way to implement this is with index_select and index_add_ (index_copy_ would work as well).
A proof of concept is below. It is not the nicest code, but it should contain all the information you need to do what you want:

import torch

batch_size = 10
feat_size = 5
x = torch.rand(batch_size, feat_size)
y = torch.randn(batch_size, 1)  # stand-in for the output of NN1
print('x', x)
# 1-D index tensors (possibly empty); view(-1) + as_tuple=True avoid the 0-d
# tensors that .squeeze() produces when only one sample is selected
nn2_ind = (y >= 0).view(-1).nonzero(as_tuple=True)[0]
nn3_ind = (y < 0).view(-1).nonzero(as_tuple=True)[0]
print('Indices for NN2', nn2_ind)
for_nn_2 = x.index_select(0, nn2_ind)
for_nn_3 = x.index_select(0, nn3_ind)
print('Input for NN2', for_nn_2)
print('Input for NN3', for_nn_3)
nn_2_out = for_nn_2 + 2  # replace with NN2
nn_3_out = for_nn_3 + 10  # replace with NN3
# With actual networks the output feature size may differ from feat_size
out = torch.zeros(batch_size, feat_size)
out.index_add_(0, nn2_ind, nn_2_out)
out.index_add_(0, nn3_ind, nn_3_out)
print('Output', out)
print('The values that used NN2 should be 2.xx and the ones using NN3 should be 10.xx')
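The same idea can be wrapped inside a module's `forward`. This is a sketch under stated assumptions: `RoutedNet` and its `nn.Linear` layers are hypothetical placeholders, the gate threshold is `>= 0` as in the proof of concept, and the out-of-place `index_add` is used so that no in-place operation touches the autograd graph. Each branch only sees its own samples, and an empty group is skipped entirely.

```python
import torch
import torch.nn as nn

class RoutedNet(nn.Module):
    """Sketch of per-sample routing with index_select / index_add.

    NN1, NN2, NN3 and all sizes are placeholders assumed for illustration.
    """

    def __init__(self, in_features=5, out_features=3):
        super().__init__()
        self.out_features = out_features
        self.NN1 = nn.Linear(in_features, 1)
        self.NN2 = nn.Linear(in_features, out_features)
        self.NN3 = nn.Linear(in_features, out_features)

    def forward(self, x):
        gate = self.NN1(x).view(-1) >= 0                # one decision per sample
        nn2_ind = gate.nonzero(as_tuple=True)[0]        # 1-D, possibly empty
        nn3_ind = (~gate).nonzero(as_tuple=True)[0]
        out = x.new_zeros(x.size(0), self.out_features)
        if nn2_ind.numel() > 0:                         # skip empty groups entirely
            out = out.index_add(0, nn2_ind, self.NN2(x.index_select(0, nn2_ind)))
        if nn3_ind.numel() > 0:
            out = out.index_add(0, nn3_ind, self.NN3(x.index_select(0, nn3_ind)))
        return out

torch.manual_seed(0)
model = RoutedNet()
x = torch.rand(10, 5)
out = model(x)
out.sum().backward()
print(out.shape)  # torch.Size([10, 3])
```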

If PyTorch hasn't complained about in-place operations, missing computation graphs or non-differentiable operations, then it is fairly safe to say that backpropagation works as expected.
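For a stronger check than "no error was raised", `torch.autograd.gradcheck` compares the analytical gradients against finite differences. A sketch, assuming a fixed routing decision (the hard gate itself is not differentiable, and finite differences near the threshold could flip it) and simple placeholder branches `tanh(x @ W2)` and `sin(x @ W3)` in double precision, as gradcheck expects:

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)
# Fixed routing decision for 6 samples (assumed; in practice it comes from NN1).
gate = torch.tensor([True, False, True, True, False, False])
nn2_ind = gate.nonzero(as_tuple=True)[0]
nn3_ind = (~gate).nonzero(as_tuple=True)[0]

# Placeholder branch weights in double precision.
W2 = torch.randn(4, 3, dtype=torch.double)
W3 = torch.randn(4, 3, dtype=torch.double)

def routed(x):
    out = x.new_zeros(x.size(0), 3)
    out = out.index_add(0, nn2_ind, torch.tanh(x.index_select(0, nn2_ind) @ W2))
    out = out.index_add(0, nn3_ind, torch.sin(x.index_select(0, nn3_ind) @ W3))
    return out

x = torch.randn(6, 4, dtype=torch.double, requires_grad=True)
ok = gradcheck(routed, (x,))   # raises on mismatch, returns True otherwise
print(ok)  # True
```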

Have you timed the various solutions?
I'd be interested to know which is fastest.
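A rough way to answer that is to time a forward and backward pass for each approach discussed here: masking, the per-sample loop, and index_select routing. This harness is only a sketch: the layer sizes, batch size, and repeat count are arbitrary assumptions, it runs on CPU (on a GPU you would need `torch.cuda.synchronize()` before reading the clock), and the relative results will depend heavily on the real NN2/NN3 and hardware.

```python
import time
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, in_features, out_features = 256, 64, 64
# Placeholder networks; real timings depend heavily on model size and device.
NN1 = nn.Linear(in_features, 1)
NN2 = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU(), nn.Linear(256, out_features))
NN3 = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU(), nn.Linear(256, out_features))

def masked(x):
    # Run both branches on the whole batch, keep one row from each.
    mask = NN1(x) > 0
    return torch.where(mask, NN2(x), NN3(x))

def looped(x):
    # One sample at a time through the chosen branch.
    gate = NN1(x).view(-1) > 0
    rows = [NN2(x[i:i + 1]) if gate[i] else NN3(x[i:i + 1]) for i in range(x.size(0))]
    return torch.cat(rows, dim=0)

def routed(x):
    # Gather each group, run its branch once, scatter the results back.
    gate = NN1(x).view(-1) > 0
    idx2 = gate.nonzero(as_tuple=True)[0]
    idx3 = (~gate).nonzero(as_tuple=True)[0]
    out = x.new_zeros(x.size(0), out_features)
    out = out.index_add(0, idx2, NN2(x.index_select(0, idx2)))
    return out.index_add(0, idx3, NN3(x.index_select(0, idx3)))

def bench(fn, x, repeats=10):
    fn(x)  # warm-up
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).sum().backward()
    return (time.perf_counter() - start) / repeats

x = torch.rand(batch_size, in_features)
for name, fn in [("masked", masked), ("looped", looped), ("routed", routed)]:
    print(f"{name}: {bench(fn, x) * 1e3:.2f} ms per forward+backward")
```

All three functions compute the same output, so any difference in the printed times comes from how the work is scheduled, not from what is computed.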