Debugging "Your training graph has changed in this iteration"

I have a “Reversible Vision Transformer”. It has a custom backward pass, since it does not cache intermediate activations.

It seems to work fine for the most part, except for the following test case (2+ GPUs):

    print("Running more complicated DDP test ...")
    x1 = torch.randn((1, 3, 224, 224)).to(rank)
    x2 = torch.randn((1, 3, 224, 224)).to(rank)
    ddp_model = DDP(model, device_ids=[rank], static_graph=True)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.0001)
    outputs1 = ddp_model(x1)
    outputs2 = ddp_model(x2)

    (outputs1 + outputs2).sum().backward()
    optimizer.step()
    print("Re-entrant DDP passed")

Then I get the error:

    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/vedant/Desktop/video-recommendation/model_contrastive/models/rev_vit.py", line 402, in ddp_test_helper
    (outputs1 + outputs2).sum().backward()
  File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/home/vedant/Desktop/video-recommendation/model_contrastive/models/rev_vit.py", line 270, in backward
    X_1, X_2, dX_1, dX_2 = layer.backward_pass(
  File "/home/vedant/Desktop/video-recommendation/model_contrastive/models/rev_vit.py", line 201, in backward_pass
    g_Y_1.backward(dY_2, retain_graph=True)
  File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/vedant/miniconda3/envs/video-rec/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Your training graph has changed in this iteration, e.g., one parameter is unused in first iteration, but then got used in the second iteration. this is not compatible with static_graph set to True.

For context, the custom backward pass for a single “ReversibleBlock” in the transformer is:

    def backward_pass(
        self,
        Y_1,
        Y_2,
        dY_1,
        dY_2,
    ):
        print(f"Backwards: {self.layer_id} device_id={Y_1.get_device()}")
        """
        equation for activation recomputation:
        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention
        """

        # temporarily record intermediate activations for G
        # and use them for the gradient calculation of G
        with torch.enable_grad():

            Y_1.requires_grad = True

            g_Y_1 = self.G(Y_1)
            assert g_Y_1.shape == Y_1.shape

            g_Y_1.backward(dY_2, retain_graph=True)

        # activation recomputation is intentional and not part of
        # the forward pass's computation graph.
        with torch.no_grad():

            X_2 = Y_2 - g_Y_1
            del g_Y_1

            dY_1 = dY_1 + Y_1.grad
            Y_1.grad = None

        # record F activations and calc gradients on F
        with torch.enable_grad():
            X_2.requires_grad = True

            f_X_2 = self.F(X_2)

            # torch.manual_seed(self.seeds["droppath"])
            # f_X_2 = drop_path(
            #     f_X_2, drop_prob=self.drop_path_rate, training=self.training
            # )

            f_X_2.backward(dY_1, retain_graph=True)

        # propagate the recomputed activations back to the start of
        # the previous block for backprop
        with torch.no_grad():
            X_1 = Y_1 - f_X_2

            del f_X_2, Y_1
            dY_2 = dY_2 + X_2.grad

            X_2.grad = None
            X_2 = X_2.detach()

        return X_1, X_2, dY_1, dY_2
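
For readers less familiar with reversible blocks: the forward pass is the additive coupling that the code above inverts. A minimal sketch of that coupling (only the sub-module names F and G come from my code; the class and everything else here is illustrative):

    import torch.nn as nn

    class ReversibleBlockSketch(nn.Module):
        """Illustrative forward coupling that backward_pass inverts."""

        def __init__(self, F: nn.Module, G: nn.Module):
            super().__init__()
            self.F = F  # attention sub-block
            self.G = G  # MLP sub-block

        def forward(self, X_1, X_2):
            # Y_1 = X_1 + F(X_2), inverted above as X_1 = Y_1 - F(X_2)
            Y_1 = X_1 + self.F(X_2)
            # Y_2 = X_2 + G(Y_1), inverted above as X_2 = Y_2 - G(Y_1)
            Y_2 = X_2 + self.G(Y_1)
            return Y_1, Y_2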

My main question is: why am I getting this error? My backward pass has no if-statements or other conditionals, so it should execute the same operations every iteration. Is there something I’m missing about how backprop works?

Full code: Reversible VIT · GitHub
Original paper (my code is a cut/paste/simplification of their implementation): https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf