Conditional computation that saves computation

Hi, thanks for the reply. The techniques posted in that thread do not work. At present, I don't know of any technique in PyTorch that does conditional computation. I have read many works, such as ACT and SkipNet, but everyone simply multiplies the outputs by 0! Any help will be very useful.
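By "multiplies the outputs by 0" I mean a gating pattern like the sketch below (toy layers, not taken from either paper): both branches are still evaluated for every sample, so no computation is actually saved.

import torch
import torch.nn as nn

nn2 = nn.Linear(5, 5)
nn3 = nn.Linear(5, 5)
x = torch.randn(8, 5)
gate = (torch.randn(8, 1) > 0).float()  # stand-in for the learned decision

# both branches run on the full batch; the "skipped" one is merely zeroed out
out = gate * nn2(x) + (1 - gate) * nn3(x)

What I would like to do instead is something like the forward below: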


def forward(self, x):
    y = self.NN1(x)
    print(x.size(), y.size())  # x.size() = (batch_size, number_features)
                               # y.size() = (batch_size, 1)
    if y > 0:
        return self.NN2(x)
    else:
        return self.NN3(x)

So the decision of whether to use NN2 or NN3 must be taken separately for each sample in the batch.

How about this…

def forward(self, x):
    y = self.NN1(x)
    results = []
    for i in range(len(y)):
        # x[i:i+1] keeps the batch dimension, so NN2/NN3 still see a (1, number_features) input
        results.append(self.NN2(x[i:i+1]) if y[i] > 0 else self.NN3(x[i:i+1]))
    return torch.cat(results, dim=0)

I doubt it will be faster, since NN2 and NN3 are probably composed of highly parallelisable matrix operations that run almost as fast on an entire batch as on a single sample. Besides, Python loops are slow.

Even assuming that the solution you mention does work, the problem is that many layers have batch normalization, which normalizes data across the batch. Here we would simply be passing data with batch size 1.

Batch norm keeps running stats, so this shouldn't break it completely.
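For instance, after training on full batches, eval mode falls back on the running stats, so even a batch of one sample is normalized consistently (a toy sketch, the layer width is arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(5)

# accumulate running statistics with full batches while training
bn.train()
for _ in range(10):
    bn(torch.randn(8, 5))

# at inference a single sample is normalized with the running stats
bn.eval()
out = bn(torch.randn(1, 5))
print(out.shape)  # torch.Size([1, 5])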

Did you test it? Does it run faster?

Hi,

I think a simple way to implement this is using index_select and index_add_ (could use index copy as well).
A proof of concept is below; this is not the nicest code, but it should contain all the information you need to do what you want:

import torch
from torch.autograd import Variable

batch_size = 10
feat_size = 5
x = torch.rand(batch_size, feat_size)
y = torch.randn(batch_size, 1)
print('x', x)

nn2_ind = (y >= 0).squeeze().nonzero().squeeze()
nn3_ind = (y < 0).squeeze().nonzero().squeeze()
print('Indices for NN2', nn2_ind)

for_nn_2 = x.index_select(0, nn2_ind)
for_nn_3 = x.index_select(0, nn3_ind)
print('Input for NN2', for_nn_2)
print('Input for NN3', for_nn_3)

nn_2_out = for_nn_2 + 2 # replace with NN2
nn_3_out = for_nn_3 + 10 # replace with NN3

# Here other sizes may differ when using actual networks
out = torch.zeros(batch_size, feat_size)

out.index_add_(0, nn2_ind, nn_2_out)
out.index_add_(0, nn3_ind, nn_3_out)
print('Output', out)
print('The values that used nn2 should be 2.** and the ones using nn3 should be 10.**')

Hi,
I haven't tried it yet. I will try it soon. Thanks for the reply!

Hi Alban, thanks for the reply! I will check this method and get back in case there are issues with it.

Hi Alban,
your technique works. I am wondering whether the network also backpropagates through it.

If PyTorch hasn't complained about in-place operations, missing computation graphs or non-differentiable operations, then it is fairly safe to say that backpropagation works as expected.

Have you timed the various solutions?
I’d be interested to know which is fastest.
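Something like this rough harness is what I have in mind (the layer sizes, batch size and 50/50 split are arbitrary; the two networks are just Linear stand-ins):

import time
import torch
import torch.nn as nn

batch_size, feat_size = 256, 128
nn2 = nn.Linear(feat_size, feat_size)
nn3 = nn.Linear(feat_size, feat_size)
x = torch.randn(batch_size, feat_size)
y = torch.randn(batch_size, 1)

def loop_version():
    # per-sample Python loop
    results = [nn2(x[i:i+1]) if y[i] >= 0 else nn3(x[i:i+1]) for i in range(batch_size)]
    return torch.cat(results, dim=0)

def index_version():
    # split the batch with index_select, merge the results with index_add_
    nn2_ind = (y >= 0).squeeze(1).nonzero().squeeze(1)
    nn3_ind = (y < 0).squeeze(1).nonzero().squeeze(1)
    out = torch.zeros(batch_size, feat_size)
    out.index_add_(0, nn2_ind, nn2(x.index_select(0, nn2_ind)))
    out.index_add_(0, nn3_ind, nn3(x.index_select(0, nn3_ind)))
    return out

with torch.no_grad():
    for name, fn in [('loop', loop_version), ('index', index_version)]:
        start = time.perf_counter()
        for _ in range(100):
            fn()
        print(name, time.perf_counter() - start)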

Hi, presently I haven’t timed them. I will do it soon though. Thanks

1 Like

Hi Alban,
it seems that the technique you suggested here doesn't work. I am having issues with index_add. The code that you provided works fine, but when I try to do something similar in my algorithm, it does not work. I could not find much information about index_add. Is it a commonly used function?

Hi,
it's not possible to backpropagate through the graph! I get the following error:
"NoneType object has no attribute data"

If I take Alban’s example, wrap x and y in Variables, and add out.sum().backward() at the end, then it backpropagates with no problem.
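Condensed, it looks roughly like this (requires_grad=True stands in for the Variable wrapping, and squeeze(1) keeps the index tensors 1-D):

import torch

batch_size, feat_size = 10, 5
x = torch.rand(batch_size, feat_size, requires_grad=True)
y = torch.randn(batch_size, 1)

nn2_ind = (y >= 0).squeeze(1).nonzero().squeeze(1)
nn3_ind = (y < 0).squeeze(1).nonzero().squeeze(1)

nn_2_out = x.index_select(0, nn2_ind) + 2   # stand-in for NN2
nn_3_out = x.index_select(0, nn3_ind) + 10  # stand-in for NN3

out = torch.zeros(batch_size, feat_size)
out.index_add_(0, nn2_ind, nn_2_out)
out.index_add_(0, nn3_ind, nn_3_out)

out.sum().backward()
print(x.grad)  # all ones, so the gradient reaches x through both branches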

Unless you can show us your code and a full stack trace, then I don’t think anyone can help you track down your bug.

Hi,
so in your case you are differentiating with respect to the inputs (say, input images)? Isn't that unconventional?
In my case the graph is more complex: I use an NN to decide whether or not to use the next layer of an MLP. However, when I try to do backpropagation, I get the error.
Thanks

The code examples work and will backpropagate correctly, and unless you provide code I can't help you any more than that.

If you don't want to show your code, can you at least produce a minimal code example that demonstrates the error?


Hi,
sorry about that, my code is very long. I am working on producing a very simple working example and will get back to you very soon.
Thanks

If I want to use this approach to conditional computation in the "central" part of my model, it leads to an autodiff issue.

The out tensor here is a leaf node. In my case, the tensor where I split my input batch and then recombine it is an intermediate node in the computational graph, since I perform further operations on it. Hence I initialized it with requires_grad=True.

As a result, when I try to use index_add_ it throws a RuntimeError:

a leaf Variable that requires grad is being used in an in-place operation.

What would be a workaround for it? This thread suggests making a clone or editing the data object of the variable directly: Leaf variable was used in an inplace operation
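If it helps, this is roughly the clone variant I have in mind (toy sizes; +2 and +10 stand in for the two sub-networks):

import torch

batch_size, feat_size = 10, 5
x = torch.rand(batch_size, feat_size, requires_grad=True)
y = torch.randn(batch_size, 1)

nn2_ind = (y >= 0).squeeze(1).nonzero().squeeze(1)
nn3_ind = (y < 0).squeeze(1).nonzero().squeeze(1)

base = torch.zeros(batch_size, feat_size, requires_grad=True)
out = base.clone()  # the clone is not a leaf, so in-place ops on it are allowed
out.index_add_(0, nn2_ind, x.index_select(0, nn2_ind) + 2)
out.index_add_(0, nn3_ind, x.index_select(0, nn3_ind) + 10)

out.sum().backward()
print(x.grad)  # all ones, so the graph stays intact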

Would love to hear feedback from either of you too, @jpeg729 or @rahul, since this thread has been around for quite a while!

Thanks in advance :slight_smile:

EDIT: Never mind, I used torch.Tensor.index_add (the out-of-place version). However, I am still facing issues with training, as my loss does not seem to be changing, so I suspect backprop is not working right.

Hi,

It is completely OK to have out not require grad and then modify it in-place with something that does. That will make out require gradients, and the gradients will propagate as expected :slight_smile:
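A tiny demo of that (toy sizes, made-up indices):

import torch

w = torch.rand(4, 3, requires_grad=True)
idx = torch.tensor([0, 2])

out = torch.zeros(4, 3)
print(out.requires_grad)  # False
out.index_add_(0, idx, w.index_select(0, idx) * 2)
print(out.requires_grad)  # True: the in-place op pulled out into the graph
out.sum().backward()
print(w.grad)  # rows 0 and 2 get gradient 2, the other rows get 0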

Note that if you are not sure whether backprop is working properly for a function (and the function is small enough), you can always try to use gradcheck.
Be aware though that you must use double precision, your function must be smooth at the point where you're checking the gradients, and if your function is too big, the numerical evaluation of the gradient might be too imprecise and the test might fail for no reason.
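For example, something along these lines (toy sizes, with *2 and *10 standing in for the two sub-networks) checks the routing function in double precision:

import torch
from torch.autograd import gradcheck

def route(x, y):
    # route each row to one of two (stand-in) branches and merge out of place
    nn2_ind = (y >= 0).squeeze(1).nonzero().squeeze(1)
    nn3_ind = (y < 0).squeeze(1).nonzero().squeeze(1)
    out = torch.zeros(x.size(0), x.size(1), dtype=x.dtype)
    out = out.index_add(0, nn2_ind, x.index_select(0, nn2_ind) * 2)
    out = out.index_add(0, nn3_ind, x.index_select(0, nn3_ind) * 10)
    return out

x = torch.rand(6, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(6, 1, dtype=torch.double)
print(gradcheck(lambda inp: route(inp, y), (x,)))  # prints True if the gradients match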

Thanks for your feedback. In your example, nn2_ind and nn3_ind will have requires_grad = False, and as far as I can tell these index tensors cannot be set to have requires_grad = True. Could that potentially be leading to issues with autodiff?