I have opened two related topics before. If one of the developers picks this up, it could be a simple extension with considerable impact.
I found:
https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html
This is quite awesome: it allows me to hook into the forward-backward pass and modify (pack/unpack) the tensors saved at graph nodes along the way.
But it has one major shortcoming (unless I am missing something): pack_hook(x) only receives the tensor to be packed. That is enough for generic packing, such as checkpointing, but many applications depend on which node the tensor belongs to.
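For reference, here is a minimal sketch of the current mechanism, using torch.autograd.graph.saved_tensors_hooks as in the tutorial above. The seen list is only there for illustration. The single argument pack_hook gets is the tensor itself:

```python
import torch

seen = []  # shapes of everything pack_hook is handed


def pack_hook(x):
    # x is the only argument: nothing tells us which op created it
    # or which op is saving it.
    seen.append(tuple(x.shape))
    return x  # no-op packing


def unpack_hook(x):
    return x


a = torch.ones(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = torch.sin(a).sum()  # sin saves its input for backward
y.backward()

print(seen)
print(a.grad)  # cos(1) for every element
```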
A node is created in the forward pass: x = f(…). If x is a tensor, it becomes a node in the autograd graph, and x will eventually be passed as an argument to pack_hook.
What I suggest is a way to pass along some sort of ID, in the example above something like “x_was_created_by_f”.
This would allow for specific packing: depending on what f is, I may have information that lets me pack x. Without this information, I cannot.
I have a concrete use case for this, but it is complicated to explain in detail. It packs arguments using a specific recurrence, which can be run backwards.
I have this working, but it is very tedious at the moment. Since I cannot pass such an ID, I need to create an annotation when computing x = f(…) that contains the information required to recognize x later on. In pack_hook, I then match x against all such annotations. This is error-prone and really complex.
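Roughly, the annotation workaround looks like the sketch below. This is a simplified, hypothetical version: annotate, labels and the data_ptr()/shape matching are assumptions for illustration, not a PyTorch API. Tensors that match no annotation are passed through unchanged.

```python
import torch

# Hypothetical annotation registry: data_ptr -> (label, shape), filled in
# right after x = f(...) is computed.
labels = {}


def annotate(t, label):
    labels[t.data_ptr()] = (label, tuple(t.shape))
    return t


packed_labels = []  # labels recognized inside pack_hook


def pack_hook(x):
    # Saved tensors share storage with the original, so data_ptr() plus
    # shape is used to recognize annotated tensors (an assumption).
    entry = labels.get(x.data_ptr())
    if entry is not None and entry[1] == tuple(x.shape):
        packed_labels.append(entry[0])
        # Label-specific packing would go here.
        return (entry[0], x)
    return (None, x)  # everything else is passed through unchanged


def unpack_hook(packed):
    _, x = packed
    return x


def f(t):
    return torch.sin(t)


a = torch.ones(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    x = annotate(f(a), "x_was_created_by_f")
    loss = torch.cos(x).sum()  # cos saves x for backward
loss.backward()

print(packed_labels)
```

Matching on data_ptr() is exactly the brittle part: storage can be reused once a tensor is freed, so a stale annotation may match the wrong tensor.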
A really simple extension of your mechanism could remove a lot of this complexity and brittleness for me, and make it a lot more useful!
Thanks for listening. I love PyTorch, but it is difficult to get such proposals through.
I thought this would nail it, but now I am struggling, because pack_hook only gets the (detached) tensor, and I see no way to identify whether it is one of the x_k. While every new node in the graph gets a unique name, that name is not passed to pack_hook.
I can match the tensors by shape, but I have several of these sequences and need to tell them apart. And most of the other tensors passed to pack_hook I just want to pass through unchanged.
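One possible stopgap, sketched under assumptions (this is not a PyTorch feature): pack_hook runs synchronously during the forward op that saves the tensor, so a context variable set around each step is visible inside the hook. Note that this tags a tensor by the op that saves it, not the op that created it; in a recurrence x_k = f(x_{k-1}), tagging step k labels the saved x_{k-1}. The names tag, _current_tag and step are hypothetical.

```python
import contextlib
import contextvars

import torch

# Hypothetical "active tag", set around each step of a sequence.
_current_tag = contextvars.ContextVar("current_tag", default=None)


@contextlib.contextmanager
def tag(name):
    """Make `name` visible to pack_hook for everything saved inside the block."""
    token = _current_tag.set(name)
    try:
        yield
    finally:
        _current_tag.reset(token)


packed = []  # tags seen by pack_hook, in call order


def pack_hook(x):
    t = _current_tag.get()
    packed.append(t)
    # Tag-specific packing would go here; untagged tensors (t is None)
    # are simply passed through.
    return (t, x)


def unpack_hook(packed_value):
    _, x = packed_value
    return x


def step(x):
    return torch.sin(x)  # sin saves its input x_{k-1} for backward


starts = {
    "seq_A": torch.full((2,), 0.5, requires_grad=True),
    "seq_B": torch.full((2,), 0.25, requires_grad=True),
}
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    ends = []
    for name, x in starts.items():
        for k in range(1, 3):
            with tag((name, k)):  # tensors saved by step k get (name, k)
                x = step(x)
        ends.append(x)
    loss = sum(torch.cos(e).sum() for e in ends)  # saved outside any tag
loss.backward()

print(packed)
```

Every tensor saved inside a tag scope gets that tag, including any intermediates f saves internally, so this only approximates the per-node ID proposed above.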