How to have different saved_tensors_hooks for different functions

When I use “function” below, I mean nodes that are present in the autograd graph.

Goal: specify how to pack and unpack saved tensors differently for different functional nodes, i.e. differently for, say, a GeLU and a matrix-multiplication function.

I see in the tutorials that I can define pack_hook and unpack_hook, which are globally defined for all tensors. The only input to these hooks is the tensor itself. Is there a way to identify which function it belongs to?
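
For concreteness, this is roughly the tutorial pattern I am referring to; the hooks apply to every saved tensor, no matter which function saved it:

import torch
from torch.autograd.graph import saved_tensors_hooks

def pack_hook(x):
    # Only the tensor being saved is passed in; nothing here says whether the
    # save comes from the GeLU, the matmul, or some other node.
    print("packing tensor of shape", tuple(x.shape))
    return x

def unpack_hook(x):
    return x

a = torch.randn(4, 4, requires_grad=True)
with saved_tensors_hooks(pack_hook, unpack_hook):
    out = torch.nn.functional.gelu(a @ a)
out.sum().backward()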

I also read that I can register pack and unpack hooks on individual tensors. However, when should I do this? The autograd graph is constructed as individual forwards are completed. If I try to register pack/unpack hooks at the end of the forward (by using a forward_hook on the module), it says the hooks can only be set once and they are already set.
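
For reference, what I tried looks roughly like the snippet below (using the _raw_saved_self attribute on the backward node, if I understand that mechanism correctly); doing the equivalent from a module forward hook is where I hit the “already set” error:

import torch

a = torch.randn(4, requires_grad=True)
b = a.sin()

# Per-tensor registration as I understand it: grab the SavedTensor object off
# the backward node and attach hooks to it directly. The attribute name
# (_raw_saved_self here) is specific to each node type.
b.grad_fn._raw_saved_self.register_hooks(
    lambda x: x,  # pack
    lambda x: x,  # unpack
)
b.sum().backward()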

  • Any guidance would help. If using pack/unpack hooks is not the correct way, what is?
  • I believe I can always define custom functions like customLinear for Linear and define what to store for backward using save_for_backward. However, that requires me to stitch together a new model from the existing model, replacing modules (like Linear) with custom modules (like customLinear). This seems like a roundabout way of doing things, so I am not sure if I should just do that (a rough sketch of what I mean is below).
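
To be explicit about the second bullet, this is the kind of workaround I mean. The names (CustomLinearFunction, CustomLinear) are mine, and the sketch assumes a plain 2-D input and bias=True:

import torch
import torch.nn as nn

class CustomLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        # Here I fully control what is stored for backward, per function,
        # without any global hook.
        ctx.save_for_backward(input, weight)
        return input @ weight.t() + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight      # (N, out) @ (out, in)
        grad_weight = grad_output.t() @ input  # (out, N) @ (N, in)
        grad_bias = grad_output.sum(dim=0)
        return grad_input, grad_weight, grad_bias

class CustomLinear(nn.Linear):
    def forward(self, input):
        return CustomLinearFunction.apply(input, self.weight, self.bias)

# Every nn.Linear in the original model would then have to be swapped for
# CustomLinear, which is the part that feels roundabout.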

You can’t directly get the backward node for which the forward is saving, unfortunately. But there’s a roundabout way to insert logic that knows which function is currently running.

First, I use a TorchFunctionMode to interpose on every call and set a global variable pointing to the currently running function.

import torch
from torch.overrides import TorchFunctionMode

current_func = None

class SetCurrentFunc(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        global current_func
        if not kwargs:
            kwargs = {}
        # Record which function is about to run so the saved-tensor hooks
        # below can inspect it, then clear it once the call returns.
        current_func = func
        out = func(*args, **kwargs)
        current_func = None
        return out

The saved tensor hooks then have access to that variable and can do different things in pack/unpack depending on which function is currently running.

from torch.autograd.graph import saved_tensors_hooks

def pack(x):
    print(current_func)
    # For sin, save a zero tensor in place of the real input; everything
    # else is saved unchanged.
    if current_func in (torch.sin, torch.Tensor.sin):
        return torch.zeros_like(x)
    else:
        return x

a = torch.tensor(3.1415 / 2., requires_grad=True)
with SetCurrentFunc(), saved_tensors_hooks(pack, lambda x: x):
    b = a.sin()
    b.backward()
    print(a.grad)  # grad_out * cos(0.) = 1

    a.grad = None
    c = a.cos()
    c.backward()  # grad_out * -sin(pi/2) = -1
    print(a.grad)


Thanks for the prompt response. This is immensely useful. I have a follow-up question / concern:
1. When running the forward, is it possible that different parts of the graph are running simultaneously, i.e. two functions running asynchronously?
2. If 1 is fine, will this piece of code also work with torch DP and DDP?

In the meantime, let me try this solution and see if it works for me. Thanks again for the answer. Immensely useful.

No problem.

  1. I don’t think so, unless you try to explicitly do some multithreading yourself (a thread-safe sketch is below, in case you do).
  2. Seems fine, since each shard of the data is processed in a different process.
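
If you do end up running forwards from several threads, one way to make the bookkeeping safe would be to swap the module-level global for thread-local state, roughly like this untested sketch (the current_func() helper is my own naming):

import threading

from torch.overrides import TorchFunctionMode

# Thread-local storage so concurrent forwards (if any) don't overwrite each
# other's "current function" bookkeeping.
_state = threading.local()

class SetCurrentFunc(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        _state.current_func = func
        try:
            return func(*args, **(kwargs or {}))
        finally:
            _state.current_func = None

def current_func():
    # Helper for the pack/unpack hooks to query the function on this thread.
    return getattr(_state, "current_func", None)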

Works as expected. Thanks a ton!

import torch
import torch.nn as nn
from torch.overrides import TorchFunctionMode

current_func = None
current_args = None

class SetCurrentFunc(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        global current_func, current_args
        if not kwargs:
            kwargs = {}
        # Record the function (and its arguments) currently executing so the
        # pack hook can attribute each saved tensor to it.
        current_func = func
        current_args = args
        out = func(*args, **kwargs)
        current_func = None
        current_args = None
        return out


model = nn.TransformerEncoderLayer(128, 2)
print(model)

def pack_hook(x):
    print("pck", x.data_ptr(), "in", current_func, "shape", x.shape, x.dtype)
    return x

def unpack_hook(x):
    return x

# Push default saved-tensor hooks globally (the non-context-manager route;
# torch.autograd.graph.saved_tensors_hooks is the public context-manager equivalent).
torch._C._autograd._push_saved_tensors_default_hooks(pack_hook, unpack_hook)
x = torch.randn(10,32,128)
y = torch.randn(10,32,128)

with SetCurrentFunc():
    yhat = model(x)
    loss = nn.functional.mse_loss(yhat, y)
    loss.backward()

Output:

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (linear1): Linear(in_features=128, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=128, bias=True)
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)
pck 104910528 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([320, 128]) torch.float32
pck 105243840 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([]) torch.float64
pck 105244288 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 64]) torch.float32
pck 106081024 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 64, 10]) torch.float32
pck 105439040 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 10]) torch.float32
pck 105465152 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 10]) torch.float32
pck 105490880 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 10]) torch.float32
pck 106246400 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 64]) torch.float32
pck 104779776 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([128, 128]) torch.float32
pck 106412032 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([320, 128]) torch.float32
pck 106739968 in <function dropout at 0x7f74c3bd8310> shape torch.Size([10, 32, 128]) torch.float32
pck 107067904 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 128]) torch.float32
pck 104894336 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 104895488 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 105416896 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 105418304 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 140139455070272 in <built-in function linear> shape torch.Size([128, 2048]) torch.float32
pck 106903936 in <built-in function linear> shape torch.Size([320, 128]) torch.float32
pck 140138382893120 in <function relu at 0x7f74c3bd8700> shape torch.Size([10, 32, 2048]) torch.float32
pck 107231872 in <function dropout at 0x7f74c3bd8310> shape torch.Size([10, 32, 2048]) torch.float32
pck 140139454017600 in <built-in function linear> shape torch.Size([2048, 128]) torch.float32
pck 109853440 in <built-in function linear> shape torch.Size([320, 2048]) torch.float32
pck 112638976 in <function dropout at 0x7f74c3bd8310> shape torch.Size([10, 32, 128]) torch.float32
pck 112475008 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 128]) torch.float32
pck 104896576 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 104897792 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 106037312 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 106038720 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 112802944 in <function mse_loss at 0x7f74c3bdbf70> shape torch.Size([10, 32, 128]) torch.float32
pck 105074496 in <function mse_loss at 0x7f74c3bdbf70> shape torch.Size([10, 32, 128]) torch.float32

Also, thanks for answering the follow-up questions. That is reassuring.
