Conditional computation that saves computation

If I want to use this approach to conditional computation in the “central” part of my model, it leads to an autodiff issue.

The out tensor here is a leaf node. In my case, the tensor into which I combine the pieces after splitting my input batch is an intermediate node in the computational graph, since I perform further operations on it. Hence I initialize it with requires_grad=True.

As a result, when I try to use index_add_, it throws a RuntimeError:

a leaf Variable that requires grad is being used in an in-place operation.
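
For reference, here is a toy reproduction of what I mean (the shapes and the two branches are made up for illustration, not my actual model): split the batch, process the pieces, then try to write them back into a leaf tensor that requires grad.

```python
import torch

# Toy reproduction of the error: the shapes and the two branches
# are illustrative stand-ins, not the real model.
x = torch.randn(6, 3, requires_grad=True)
mask = x[:, 0] > 0
idx_a = torch.where(mask)[0]
idx_b = torch.where(~mask)[0]

out_a = x[idx_a] * 2.0                        # branch A on part of the batch
out_b = x[idx_b] * 3.0                        # branch B on the rest

out = torch.zeros(6, 3, requires_grad=True)   # leaf tensor that requires grad

# RuntimeError: a leaf Variable that requires grad is being used
# in an in-place operation.
out.index_add_(0, idx_a, out_a)
```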

What would be a workaround for this? This thread suggests making a clone or editing the variable's .data attribute directly: Leaf variable was used in an inplace operation
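
If I understand the clone suggestion correctly, something like the following rough sketch (toy shapes, not my real code) should avoid the error, since the in-place index_add_ then happens on a non-leaf node and autograd still tracks it:

```python
import torch

# Rough sketch of the clone workaround: clone the leaf first, then
# do the in-place write on the non-leaf copy.
src = torch.randn(2, 3, requires_grad=True)
idx = torch.tensor([0, 2])

buf = torch.zeros(4, 3, requires_grad=True)   # leaf tensor
out = buf.clone()                             # non-leaf copy with a grad_fn

out.index_add_(0, idx, src)                   # no RuntimeError here
out.sum().backward()                          # gradients reach both buf and src
```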

Would love to hear your feedback too, @jpeg729 or @rahul, since this thread has been around for quite a while!

Thanks in advance :slight_smile:

EDIT: Never mind, I used torch.Tensor.index_add (the out-of-place version) instead. However, I am still facing issues with training, as my loss does not seem to change, so I suspect backprop is not working correctly.
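
A toy check with stand-in tensors (again, not my model) does show gradients flowing through the out-of-place index_add, so I am trying to narrow down where backprop might actually be breaking:

```python
import torch

# Toy check that gradients flow through the out-of-place index_add.
src = torch.randn(2, 3, requires_grad=True)   # stand-in for the branch outputs
idx = torch.tensor([0, 2])
base = torch.zeros(4, 3)

out = base.index_add(0, idx, src)             # new tensor, stays in the graph
out.sum().backward()
print(src.grad)                               # non-None: grads flow through index_add
```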