I’m trying to train a model that’s composed of two submodules, B and D, on input I.
Let X_b = B(I) be the intermediate feature vectors obtained by applying B to I.
D is then applied to X_b to get the prediction, on which the loss L_s is computed.
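To make the setup concrete, here is a minimal PyTorch sketch. It assumes small stand-in modules (B as a Linear+ReLU feature extractor, D as a Linear head) and uses MSE as a placeholder for L_s; the layer sizes, loss_fn, and target are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two submodules; any nn.Module works the same way.
B = nn.Sequential(nn.Linear(8, 16), nn.ReLU())  # feature extractor
D = nn.Linear(16, 1)                            # prediction head
loss_fn = nn.MSELoss()                          # placeholder for L_s

I = torch.randn(4, 8)       # a batch of 4 inputs
target = torch.randn(4, 1)  # hypothetical targets

X_b = B(I)                   # intermediate features, X_b = B(I)
pred = D(X_b)                # prediction from D
L_s = loss_fn(pred, target)  # scalar loss
```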

What I’m trying to do is as follows:

Step 1: do the above-mentioned forward pass through the entire model, compute the loss, and backpropagate gradients through both B and D.

Step 2: Apply an augmentation/permutation f() to X_b to get X_b’, run a forward pass of submodule D on X_b’, and compute the same loss, but this time update only one of B or D.

I’m attaching an outline sketch of the process. Solid lines of the same colour represent a forward pass, and dashed lines represent backward gradient updates.

What is the best way to achieve this? I’ve tried quite a few things, including setting requires_grad=False for submodule D after Step 1, but every variation I have tried ends up not updating the weights of D at all.

You don’t use the word “optimize” in your description, but I assume that you
want to backpropagate the loss L_s in order to populate the gradients of B
and D and then take an optimization step.
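As a baseline, a plain Step 1 with one optimizer per submodule could look like this. The stand-in modules, SGD, and lr=0.1 are assumptions for illustration; opt_B and opt_D match the names used in this reply.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for B and D.
B = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
D = nn.Linear(16, 1)
loss_fn = nn.MSELoss()

# One optimizer per submodule, so each can be stepped (or skipped) independently.
opt_B = torch.optim.SGD(B.parameters(), lr=0.1)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)

I = torch.randn(4, 8)
target = torch.randn(4, 1)

opt_B.zero_grad()
opt_D.zero_grad()
L_s = loss_fn(D(B(I)), target)
L_s.backward()  # populates .grad for every parameter of B and D

D_before = D.weight.detach().clone()
opt_B.step()
opt_D.step()
```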

You make the logic a little more complicated by saying in Step 2 that you
want to only update either B or D.

Let’s start with the case where you update D in Step 2:

Perform your Step 1 forward pass, computing L_s and keeping a reference
to X_b. Make a copy of X_b, X_c = X_b.detach().clone(), pass f(X_c)
through D, and compute a second value for the loss L_s. Sum the two
losses together and backpropagate.

This will populate B with gradients due to the Step 1 forward pass only
(the .detach() keeps the Step 2 loss from reaching B) and will populate D
with the sum of its Step 1 and Step 2 forward-pass gradients.
You can now optimize with opt_B.step() and opt_D.step().
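A sketch of this first recipe, using the same hypothetical stand-ins. The augmentation f here is an assumption (a fixed permutation of the feature dimension); swap in your own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for B and D.
B = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
D = nn.Linear(16, 1)
loss_fn = nn.MSELoss()
opt_B = torch.optim.SGD(B.parameters(), lr=0.1)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)

perm = torch.randperm(16)

def f(x):
    # Hypothetical augmentation: a fixed permutation of the feature dimension.
    return x[:, perm]

I = torch.randn(4, 8)
target = torch.randn(4, 1)

opt_B.zero_grad()
opt_D.zero_grad()

# Step 1: full forward pass, keeping a reference to X_b.
X_b = B(I)
loss_1 = loss_fn(D(X_b), target)

# Step 2: a detached copy, so this loss has no path back to B's parameters.
X_c = X_b.detach().clone()
loss_2 = loss_fn(D(f(X_c)), target)

# One backward on the sum: B gets Step 1 grads only, D gets Step 1 + Step 2 grads.
(loss_1 + loss_2).backward()

opt_B.step()
opt_D.step()
```

Because X_c is detached, loss_2 cannot reach B, so a single backward call on the sum gives exactly the split described above.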

For the case where you update B in Step 2:

Perform the forward pass just through B, keeping a reference to X_b.
Pass f(X_b) through D (no .detach() or .clone()) and compute L_s. Call
L_s.backward(retain_graph=True); retain_graph is needed because the
second backward call below goes through B’s part of the graph again.
This populates B with its Step 2 gradients. (It also populates D with
undesired gradients.)

Now call opt_D.zero_grad(). This zeros out the Step 2 gradients for D
that you said you didn’t want. Complete the Step 1 forward pass by
passing X_b (to which you kept a reference) through D, compute L_s, and
call L_s.backward(). This populates D’s Step 1 gradients and accumulates
B’s Step 1 gradients onto its Step 2 gradients.
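A sketch of the second recipe, with the same hypothetical modules and augmentation f (both assumptions). After the two backward calls, both optimizers can be stepped as before.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for B and D.
B = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
D = nn.Linear(16, 1)
loss_fn = nn.MSELoss()
opt_B = torch.optim.SGD(B.parameters(), lr=0.1)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)

perm = torch.randperm(16)

def f(x):
    # Hypothetical augmentation: a fixed permutation of the feature dimension.
    return x[:, perm]

I = torch.randn(4, 8)
target = torch.randn(4, 1)

opt_B.zero_grad()
opt_D.zero_grad()

# Forward through B only; X_b stays attached to B's graph.
X_b = B(I)

# Step 2 first: gradients flow into B (and, unwantedly, into D).
loss_2 = loss_fn(D(f(X_b)), target)
loss_2.backward(retain_graph=True)  # keep B's graph for the second backward

# Discard D's Step 2 gradients; B's gradients are left untouched.
opt_D.zero_grad()

# Step 1: reuse X_b. D gets Step 1 grads; B accumulates Step 1 onto Step 2.
loss_1 = loss_fn(D(X_b), target)
loss_1.backward()

opt_B.step()
opt_D.step()
```

Since B’s .grad is never cleared between the two backward calls, its Step 1 and Step 2 contributions add up, while opt_D.zero_grad() discards only D’s Step 2 part.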