Simple extension of autograd saved tensor hook mechanism

I have written two related topics before. If one of the developers picks this up, this could be a simple extension with considerable impact.

I found:

https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html

This is quite awesome: it allows me to hook into the forward and backward passes and modify (pack/unpack) the tensors saved along the way.

But it has one major shortcoming (unless I am missing something): pack_hook(x) receives only the tensor to be packed. That is enough for generic packing, such as checkpointing, but many applications need to know which node the tensor belongs to.
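For comparison, here is what a generic pack/unpack pair looks like. This is a minimal sketch loosely modeled on the CPU-offloading idea from the linked tutorial; pack_to_cpu and unpack_from_cpu are illustrative names, not PyTorch API. The hooks only ever see the tensor itself, so nothing in them can know which node saved it, which is exactly the limitation described above.

```python
import torch

def pack_to_cpu(t):
    # Generic packing: remember the original device and keep a CPU copy.
    # The hook only receives the tensor, so it cannot tell which node saved it.
    return (t.device, t.to("cpu"))

def unpack_from_cpu(packed):
    # Move the tensor back to its original device when backward needs it.
    device, t = packed
    return t.to(device)

a = torch.randn(5, requires_grad=True)
b = torch.randn(5, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (a * b).sum()  # MulBackward0 saves a and b through the hooks
y.backward()

# d/da sum(a * b) = b and d/db sum(a * b) = a
print(torch.allclose(a.grad, b), torch.allclose(b.grad, a))
```

The packed value can be any Python object (here a tuple), as long as the unpack hook turns it back into a tensor with the same content.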

A node is created in the forward pass by x = f(…). If x is a tensor, it becomes a node in the autograd graph, and x will eventually be passed as an argument to pack_hook.

What I suggest is the ability to pass along some sort of ID, e.g. “x_was_created_by_f” in the example above.

This would allow node-specific packing: depending on what f is, I may have information that lets me pack x. Without that information, I cannot.

I have a concrete use case for this, but it is complex to explain in detail. It packs arguments using a specific recurrence that can be run backwards.

I have this working, but it is very tedious at the moment. Since I cannot pass such an ID, I have to create an annotation when computing x = f(…) that contains enough information to recognize the tensor x later on. In pack_hook, I then match x against all such annotations. This is error-prone and very complex.

A very simple extension of this mechanism could remove a lot of complexity and brittleness for me, and make this great mechanism much more useful!

Thanks for listening. I love PyTorch, but it is difficult to get proposals like this heard.

I thought this would solve it, but I am struggling now, because pack_hook only gets the (detached) tensor and I see no way to tell whether it is one of the x_k. Every new node in the graph gets a unique name, but that name is not passed to pack_hook.

I can match the tensors by shape, but I have several of these sequences and need to tell them apart. And for many of the tensors pack_hook is called with, I just want to pass them through unchanged.

Even better: could somebody suggest a way to do this now, with the mechanism as it is? I just do not see one.

To summarize: at the point where I write x = f(…), I can create some kind of ID for x. I would like this ID to be passed to pack_hook, which currently receives only the raw tensor. That would let me recognize this particular tensor among all the ones pack_hook is called with.

My current solution is to store the shape and dtype, plus a randomly sampled part of the tensor's content, and match on that. This is brittle, because several different x could have the same representation. It is also very complex.

Hi! I think it depends on the granularity at which you want to make the packing and unpacking specific:

  • Specific packing / unpacking function per Node instance (e.g. two different MulBackward0 nodes may have different packing and unpacking functions).
  • Specific packing / unpacking function per Node class (e.g. two different MulBackward0 nodes will always have the same packing and unpacking functions).
  • Specific packing / unpacking function per nn.Module (e.g. all the Nodes created by the forward pass of a given nn.Linear layer will have the same packing and unpacking functions).

For the first two cases, I don't know of an elegant solution.

For the third case, one idea would be to wrap each leaf module so that its forward pass runs inside a with torch.autograd.graph.saved_tensors_hooks(pack, unpack): context that uses the appropriate pack and unpack functions for that module.
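A minimal sketch of the per-module idea, assuming plain nn.Linear layers; HookedModule and make_hooks are illustrative names, not PyTorch API. Each wrapper runs its inner module's forward under its own pack/unpack pair, and the hooks close over the module's name, so every packed tensor is attributed to the layer that saved it.

```python
import torch
from torch import nn

packed_by = []  # (module name, shape) for every tensor packed inside a wrapper

def make_hooks(name):
    # The hooks close over `name`, which acts as a per-module ID.
    def pack(t):
        packed_by.append((name, tuple(t.shape)))
        return t
    def unpack(t):
        return t
    return pack, unpack

class HookedModule(nn.Module):
    """Runs the wrapped module's forward inside its own saved_tensors_hooks context."""
    def __init__(self, name, inner):
        super().__init__()
        self.inner = inner
        self.hooks = make_hooks(name)

    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(*self.hooks):
            return self.inner(x)

model = nn.Sequential(
    HookedModule("fc1", nn.Linear(8, 16)),
    nn.ReLU(),  # not wrapped: whatever it saves is packed without any hooks
    HookedModule("fc2", nn.Linear(16, 4)),
)
model(torch.randn(2, 8)).sum().backward()
print(sorted({name for name, _ in packed_by}))
```

This attributes tensors per module only; tensors saved by the same module still share one pair of hooks.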

Hello, thanks for the advice. In my use case, a per-module solution does not quite fit, because the different nodes to pack are all jumbled up. To be more precise, I am working on long-context fine-tuning, and the nodes in question are inputs to the scaled_dot_product_attention operator.

But even so: isn't it the case that the whole forward and backward pass, up until loss.backward(), needs to be wrapped in a single saved_tensors_hooks context?

Let me check this. If I can use a different context for every layer of the model, your proposal could already make my approach more robust.

> But even so: My understanding is that the whole forward-backward, up until loss.backward(), needs to be wrapped in a single saved_tensors_hooks context?

I don’t think so. In one of the examples from the tutorial you linked, they do this:

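```python
import torch

def pack_hook(x):
    print("Packing", x)
    return x

def unpack_hook(x):
    print("Unpacking", x)
    return x

a = torch.ones(5, requires_grad=True)
b = torch.ones(5, requires_grad=True) * 2

# Only the forward computation runs inside the hooks context.
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = a * b

# backward() runs outside the context; unpack_hook is still called here.
y.sum().backward()
```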

So it seems that only the forward computation needs to be wrapped in the saved_tensors_hooks context.

This makes me think that different contexts can be used for different parts of the forward pass.
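This can be checked with a small sketch (make_hooks and the labels "part1"/"part2" are illustrative): two separate contexts cover two parts of the forward pass, and backward() runs outside both. Each saved tensor keeps the hooks that were active when it was saved, so its unpack hook still fires during backward.

```python
import torch
from collections import Counter

log = []

def make_hooks(label):
    # Each hook pair tags its log entries, showing which context handled a tensor.
    def pack(t):
        log.append(("pack", label))
        return t
    def unpack(t):
        log.append(("unpack", label))
        return t
    return pack, unpack

a = torch.randn(5, requires_grad=True)
b = torch.randn(5, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(*make_hooks("part1")):
    y = a * b    # MulBackward0 saves a and b -> packed by the "part1" hooks
with torch.autograd.graph.saved_tensors_hooks(*make_hooks("part2")):
    z = y.exp()  # ExpBackward0 saves its result -> packed by the "part2" hooks

z.sum().backward()  # outside both contexts; each unpack hook still runs
print(Counter(log))
```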