Is a sub-batch-tensor based on a discrete selection operation still differentiable when passed through multiple models?

I have 3 sub-models in my end-to-end pipeline: A, B and C. And I have an input IP with a batch size of 5 (any number; 5 for convenience).

Let’s say - I have 2 steps of operations:

  1. IP [5 x N] → A → OP1 [5 x H] → B → OP2 [5 x 2]
    Here the input is sent through model A, and A's output (OP1) is then sent through model B, which gives a [BatchSize x 2] output that can be used to obtain a Boolean tensor of size BatchSize; let's say: OP2 = [0, 1, 0, 1, 0]

Now, I want to use this tensor (OP2) to select a sub-tensor from OP1, the output of model A.

OP1_SubTensor = OP1[OP2.nonzero(), :]
  2. Ideally, I want to use OP1_SubTensor as the input for model C:

OP1_SubTensor [2 x H] → C → OP3 [2 x M]
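
To make that selection step concrete, here is a minimal sketch of what I mean (layer sizes and the hard-coded OP2 are just placeholders, not my real models):

import torch
import torch.nn as nn

#placeholder sizes: N = 4, H = 6 (hypothetical, just for this sketch)
A = nn.Linear(4, 6)
ip = torch.randn(5, 4)                      #IP  [5 x N]
op1 = A(ip)                                 #OP1 [5 x H]
op2 = torch.tensor([0, 1, 0, 1, 0])         #stand-in for the Boolean tensor from model B
op1_sub = op1[op2.nonzero(), :].squeeze(1)  #OP1_SubTensor [2 x H]

print(op1_sub.grad_fn)            #not None, i.e. the sub-tensor is still on the autograd graph
op1_sub.sum().backward()
print(A.weight.grad.abs().sum())  #non-zero, i.e. gradients reach model A through the selection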

I have 2 losses! One on the output of model B (Loss1) and one on the output of model C (Loss2).

I am wondering: if I use this discrete selection on the output of model A (OP1) and then pass that sub-tensor to model C, will it all be optimized end-to-end when I backprop on Loss1 and Loss2?

Or should I instead filter IP [5 x N] with the OP2 Boolean tensor and pass the filtered input through model A again?
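
In code, that alternative would look roughly like this (again just a sketch with placeholder sizes; nothing here is my actual pipeline):

import torch
import torch.nn as nn

#placeholder sizes: N = 4, H = 6, M = 3 (hypothetical)
modelA, modelC = nn.Linear(4, 6), nn.Linear(6, 3)
IP = torch.randn(5, 4)
OP2 = torch.tensor([0, 1, 0, 1, 0])  #selection obtained earlier from model B

ip_sub = IP[OP2.bool(), :]  #[2 x N] - keep only the selected samples
op1_sub = modelA(ip_sub)    #[2 x H] - model A run a second time, on the sub-batch
op3 = modelC(op1_sub)       #[2 x M]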

Kinda confused!

So, upon checking other questions here and building on 2 replies -

@ptrblck How to Detach specific components in the loss? - #8 by ptrblck

and,
@albanD Is Loss.backward() function calculate gradients over mini-batch - #5 by albanD

It seems there are 2 possible approaches?

import torch
import torch.nn as nn

BS = 5 #batch size
modelA = nn.Linear(5,10)
modelB = nn.Linear(10, 2)
modelC = nn.Linear(10, 20) #model C takes (a subset of) model A's 10-dim output
x = torch.randn(BS, 5)

#Approach 1 - Clone to duplicate the computation graph!
a = modelA(x)
a_copy = a.clone()                            #model B works on a copy of model A's output
b = modelB(a_copy)
filter_index = torch.topk(b, 1)[1].squeeze()  #per-sample argmax, e.g. tensor([0, 1, 1, 1, 0])
a = a[filter_index.nonzero(), :].squeeze(1)   #keep only the selected rows of model A's output
c = modelC(a)

print(modelA.weight.grad) #None - no backward pass has run yet
print(modelB.weight.grad) #None
print(modelC.weight.grad) #None

b.mean().backward(retain_graph = True) #Loss1 stand-in: fills modelA's and modelB's grads
#print(modelA.weight.grad)
#print(modelB.weight.grad)
#print(modelC.weight.grad)

c.mean().backward() #Loss2 stand-in: fills modelC's grads and accumulates into modelA's
#print(modelA.weight.grad)
#print(modelB.weight.grad)
#print(modelC.weight.grad)


#Clear all grads before trying the next approach:
modelA.weight.grad = None
modelB.weight.grad = None
modelC.weight.grad = None


#Approach 2 - No clone; just retain_graph so the shared graph is not freed?
a = modelA(x)
b = modelB(a)
filter_index = torch.topk(b, 1)[1].squeeze()  #per-sample argmax, e.g. tensor([0, 1, 1, 1, 0])
a = a[filter_index.nonzero(), :].squeeze(1)   #keep only the selected rows of model A's output
c = modelC(a)

print(modelA.weight.grad)
print(modelB.weight.grad)
print(modelC.weight.grad)

b.mean().backward(retain_graph = True)
#print(modelA.weight.grad)
#print(modelB.weight.grad)
#print(modelC.weight.grad)

c.mean().backward()
#print(modelA.weight.grad)
#print(modelB.weight.grad)
#print(modelC.weight.grad)

And upon gradient checking, both seem to work fine! After the first .backward(), modelA's and modelB's gradients are populated, and the second .backward() adds further gradient to modelA and populates modelC's?
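
For reference, the gradient check I mean is roughly this (it continues the script above and just prints which .grad tensors are populated):

#rough check of which .grad tensors are populated after each backward
for name, m in [("A", modelA), ("B", modelB), ("C", modelC)]:
    g = m.weight.grad
    print(name, "no grad yet" if g is None else g.abs().sum().item())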

So, is there any potential difference between these 2 approaches? It seems both require me to specify retain_graph. Unless I do something like this -

#Approach 3 - re-run the forward pass as above, then combine both losses:
loss = b.mean() + c.mean()
loss.backward() #single backward, no retain_graph needed

And this seems to work as well?! Without retain_graph, since the two losses now share a single combined computation graph (maybe?).

So, basically: are all 3 approaches the same, with neither better than the other? And does the a = a[filter_index.nonzero(), :].squeeze(1) step, which discretely and partially filters modelA's output batch, pose no issue of breaking the computation graph?
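
For what it's worth, a standalone way to probe that last point - whether the row selection lets gradients through - could be something like this (placeholder layer sizes, not the real pipeline):

import torch
import torch.nn as nn

modelA, modelC = nn.Linear(5, 10), nn.Linear(10, 20)
x = torch.randn(5, 5)

a_full = modelA(x)
a_full.retain_grad()                          #keep the grad of this intermediate output
mask = torch.tensor([0, 1, 1, 1, 0])          #stand-in for the model-B-derived selection
a_sub = a_full[mask.nonzero(), :].squeeze(1)  #[3 x 10]
c = modelC(a_sub)

c.mean().backward()
print(a_full.grad)  #only the selected rows are non-zero; the unselected rows just get zero gradient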

PS: I checked this reply about memory efficiency using gradient accumulation over batches - Using sub-batches to avoid busting the memory? - #2 by albanD - and if multiple batches can be evaluated (.backward()) before a single optimizer step to accumulate gradients, then a part of the batch output can be used to evaluate the model as well, right?
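
Concretely, the kind of single-optimizer-step setup I have in mind is this sketch (placeholder modules and a hard-coded selection mask, not the real pipeline):

import torch
import torch.nn as nn

modelA, modelB, modelC = nn.Linear(5, 10), nn.Linear(10, 2), nn.Linear(10, 20)
opt = torch.optim.SGD(
    list(modelA.parameters()) + list(modelB.parameters()) + list(modelC.parameters()),
    lr=0.1,
)

x = torch.randn(5, 5)
opt.zero_grad()

a = modelA(x)
b = modelB(a)
keep = torch.tensor([0, 1, 1, 1, 0])         #stand-in for the selection derived from b
c = modelC(a[keep.nonzero(), :].squeeze(1))  #only part of the batch goes through model C

b.mean().backward(retain_graph=True)  #Loss1 stand-in: accumulates into A's and B's .grad
c.mean().backward()                   #Loss2 stand-in: accumulates further into A's and C's .grad
opt.step()                            #one optimizer step over the accumulated gradients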