Activation Checkpointing Behavior with Branches

Hi all!

I’m trying to understand whether, and how, activation checkpointing affects the order of gradient propagation during the backward pass, and I hope someone can help me.

To set up a contrived example:

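import torch
from torch import nn
from torch.nn import Linear
from torch.utils.checkpoint import checkpoint
from jaxtyping import Float  # assumed source of the Float[...] annotations

class L1(nn.Module):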
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.l = Linear(64, 8192)

    def forward(self, x: Float[torch.Tensor, "B 64"]):
        print("l1")
        return self.l(x)

class L2(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.l = Linear(8192, 16)

    def forward(self, x: Float[torch.Tensor, "B 8192"]):
        print("l2")
        return self.l(x)

class L3(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.l = Linear(16, 1)

    def forward(self, x: Float[torch.Tensor, "B 16"]):
        print("l3")
        return self.l(x)

class Wrap(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.l1 = L1()
        self.l2 = L2()
        self.l3 = L3()

    def forward(self, x: Float[torch.Tensor, "B 64"]):
        def forward_layers(x):
            y = self.l1(x)
            y = self.l2(y)
            y = self.l3(y)
            return y

        out = None
        for i in range(50):
            # y = checkpoint(forward_layers, x, use_reentrant=False)
            y = forward_layers(x)
            out = (out + y) if out is not None else y
        return out

What I observe is that if I run this module through a forward pass, a trivial loss (sum), and then a backward pass, the peak memory consumption with checkpointing enabled appears to stay constant regardless of how many branches I create (the number of loop iterations), whereas with checkpointing off, memory consumption increases linearly with the number of parallel branches (as expected).

I’m basically trying to understand what it is in the backward pass that actually implements this, and whether this is intentional or if it’s just an artifact of how it happens to be implemented currently.

As a new user I only get two links, so I’ll post my refs here and reference them below:

I got as far as understanding:

  • The saved_tensors_hook in the non-reentrant checkpoint (checkpoint.py) and that the unpack hook triggers the recomputation on access (checkpoint.py#L1125).
  • I couldn’t quite figure out how this hooks into the C++ autograd engine. My best guess was that this object held the references to these tensors here (python_function.h), and there seem to be corresponding methods for executing the unpack. I had a bit of trouble following the types so couldn’t quite figure out how exactly this hooked into the C++ autograd engine execute, but it seems like these are lazily unpacked when they’re needed(?)
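To make the "lazily unpacked" part concrete, here is a minimal sketch that only traces when the hooks fire. It uses torch.autograd.graph.saved_tensors_hooks directly rather than the checkpoint implementation, and the function name trace_saved_tensor_hooks is made up for illustration. The pack hook should run while the forward pass saves a tensor, and the unpack hook should run only once the backward node actually reads it.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def trace_saved_tensor_hooks():
    """Record when pack/unpack hooks fire relative to forward and backward."""
    events = []

    def pack(tensor):
        # Called during forward, when autograd saves a tensor for backward.
        events.append("pack")
        return tensor

    def unpack(packed):
        # Called during backward, when a node reads its saved tensor.
        events.append("unpack")
        return packed

    x = torch.randn(4, requires_grad=True)
    with saved_tensors_hooks(pack, unpack):
        y = x.sin()  # the sin backward node saves its input
    events.append("forward done")
    y.sum().backward()
    events.append("backward done")
    return events


if __name__ == "__main__":
    print(trace_saved_tensor_hooks())
```

In the non-reentrant checkpoint, the unpack hook is where the recomputation gets triggered, so the same timing would apply there: nothing is recomputed until a backward node asks for its saved tensor.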

Assuming those are roughly “in the ballpark”, the question that I’m the most confused about is understanding the order in which the autograd engine goes through the graph. The things that I think are the case:

  • The autograd engine spins up one thread per device, but otherwise doesn’t do anything concurrently. For example, if all of the tensors are on one device, the 50 branches created by my for loop will not be backprop-ed concurrently.
  • Torch does a topological sort by distance from the leaf nodes.
  • The starting set of nodes are populated with the “roots”.

The observation that, with checkpointing, I can increase the number of loop iterations with minimal impact on memory suggests to me that torch is backprop-ing through an entire branch (l3 → l2 → l1) before proceeding to another, so that the recomputed activations can be discarded.
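One way to check this hypothesis empirically is to log events during the backward pass. The sketch below is a scaled-down version of the setup above (tiny layers, 3 branches); trace_backward, make_branch, and make_hook are names invented for this example. Each branch logs when its forward function runs, and a tensor hook on the branch output logs when its gradient arrives. If the hypothesis holds, the backward log should alternate between "gradient reached branch i" and "branch i recomputed", one branch at a time.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def trace_backward(num_branches=3, use_checkpoint=True):
    """Return only the events that were logged during the backward pass."""
    torch.manual_seed(0)
    layers = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2), nn.Linear(2, 1))
    x = torch.randn(5, 4)
    log = []

    def make_branch(i):
        def branch(inp):
            # Runs in forward, and again in backward if the branch is recomputed.
            log.append(("forward", i))
            return layers(inp)
        return branch

    def make_hook(i):
        def hook(grad):
            # Fires when the gradient w.r.t. this branch's output is ready.
            log.append(("grad", i))
        return hook

    out = None
    for i in range(num_branches):
        branch = make_branch(i)
        if use_checkpoint:
            y = checkpoint(branch, x, use_reentrant=False)
        else:
            y = branch(x)
        y.register_hook(make_hook(i))
        out = y if out is None else out + y

    forward_events = len(log)
    out.sum().backward()
    return log[forward_events:]


if __name__ == "__main__":
    for event in trace_backward():
        print(event)
```

Calling trace_backward(use_checkpoint=False) gives the same gradient order without any recomputation entries, which makes it easy to compare the two cases side by side.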

I think the backprop ordering is constrained by the topological sort (all dependencies need to be satisfied), but that it’s otherwise controlled by this sequence number (engine.cpp). However, I’m a bit lost on:

a. how/where this sequence number gets set to begin with, and
b. how the recomputation during the unpack hook (checkpoint.py#L1124-L1125) is able to get added to the computational graph (and where the sequence number would be set to get it to the top(?) of the priority queue). My understanding is that the non-reentrant checkpoint is specifically not invoking .backward() a second time, so would that mean it’s all part of the same GraphTask?
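To pin down the ordering model I have in mind, here is a toy simulation in plain Python. It is not the real engine: it assumes that each node gets a sequence number that increases in creation order during the forward pass, that a node becomes ready once all of its consumers have run, and that the ready queue always pops the node with the highest sequence number. build_graph and backward_order are hypothetical helpers, and each Linear is collapsed into a single node.

```python
import heapq


def build_graph(num_branches):
    """Mimic the loop above: returns ({name: (seq, [backward successors])}, root)."""
    nodes = {}
    seq = 0
    out = None
    for i in range(num_branches):
        prev = None
        for layer in ("l1", "l2", "l3"):
            name = f"{layer}_{i}"
            # Assumption: sequence numbers grow in forward creation order.
            nodes[name] = (seq, [prev] if prev else [])
            seq += 1
            prev = name
        if out is None:
            out = prev
        else:
            name = f"add_{i}"  # out = out + y
            nodes[name] = (seq, [out, prev])
            seq += 1
            out = name
    return nodes, out


def backward_order(nodes, root):
    """Visit nodes in dependency order, highest sequence number first."""
    # Number of consumers that must run before a node is ready.
    pending = {name: 0 for name in nodes}
    for _, successors in nodes.values():
        for succ in successors:
            pending[succ] += 1

    ready = [(-nodes[root][0], root)]  # negate seq to get a max-heap
    order = []
    while ready:
        _, name = heapq.heappop(ready)
        order.append(name)
        for succ in nodes[name][1]:
            pending[succ] -= 1
            if pending[succ] == 0:
                heapq.heappush(ready, (-nodes[succ][0], succ))
    return order


if __name__ == "__main__":
    print(backward_order(*build_graph(3)))
```

Under these assumptions the most recently created branch is always drained completely before the next add node is popped, which would explain the branch-at-a-time behavior without any special handling for checkpointing.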

Am I on the right track here? I feel like I’m close and would really appreciate some pointers to help me get over the hump!

Hello,

You are deep into understanding how activation checkpointing affects memory usage and gradient computation in PyTorch. I’ll provide some insights and clarification on the key points related to checkpointing, the backward pass, and the order of gradient propagation.

Thanks for sharing this type of information.

Best regards!