Works as expected. Thanks a ton!
import torch
import torch.nn as nn
from torch.overrides import TorchFunctionMode
import pdb
current_func = None
current_args = None
# TorchFunctionMode that records which torch-level function is currently executing,
# so the pack hook below can attribute each saved tensor to that function.
class SetCurrentFunc(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        global current_func, current_args
        if not kwargs:
            kwargs = {}
        current_func = func
        current_args = args
        out = func(*args, **kwargs)
        current_func = None
        return out
model = nn.TransformerEncoderLayer(128, 2)
print(model)
# Pack hook: logs every tensor saved for backward along with the function that saved it.
def pack_hook(x):
    print("pck", x.data_ptr(), "in", current_func, "shape", x.shape, x.dtype)
    return x

def unpack_hook(x):
    return x

# Register the hooks as the default saved-tensor hooks for all subsequent autograd saves (private API).
torch._C._autograd._push_saved_tensors_default_hooks(pack_hook, unpack_hook)
x = torch.randn(10,32,128)
y = torch.randn(10,32,128)
with SetCurrentFunc():
    yhat = model(x)
    loss = nn.functional.mse_loss(yhat, y)
    loss.backward()
Output:
TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (linear1): Linear(in_features=128, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=128, bias=True)
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)
pck 104910528 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([320, 128]) torch.float32
pck 105243840 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([]) torch.float64
pck 105244288 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 64]) torch.float32
pck 106081024 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 64, 10]) torch.float32
pck 105439040 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 10]) torch.float32
pck 105465152 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 10]) torch.float32
pck 105490880 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 10]) torch.float32
pck 106246400 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([64, 10, 64]) torch.float32
pck 104779776 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([128, 128]) torch.float32
pck 106412032 in <function multi_head_attention_forward at 0x7f74c3b661f0> shape torch.Size([320, 128]) torch.float32
pck 106739968 in <function dropout at 0x7f74c3bd8310> shape torch.Size([10, 32, 128]) torch.float32
pck 107067904 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 128]) torch.float32
pck 104894336 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 104895488 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 105416896 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 105418304 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 140139455070272 in <built-in function linear> shape torch.Size([128, 2048]) torch.float32
pck 106903936 in <built-in function linear> shape torch.Size([320, 128]) torch.float32
pck 140138382893120 in <function relu at 0x7f74c3bd8700> shape torch.Size([10, 32, 2048]) torch.float32
pck 107231872 in <function dropout at 0x7f74c3bd8310> shape torch.Size([10, 32, 2048]) torch.float32
pck 140139454017600 in <built-in function linear> shape torch.Size([2048, 128]) torch.float32
pck 109853440 in <built-in function linear> shape torch.Size([320, 2048]) torch.float32
pck 112638976 in <function dropout at 0x7f74c3bd8310> shape torch.Size([10, 32, 128]) torch.float32
pck 112475008 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 128]) torch.float32
pck 104896576 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 104897792 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([128]) torch.float32
pck 106037312 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 106038720 in <function layer_norm at 0x7f74c3bdb790> shape torch.Size([10, 32, 1]) torch.float32
pck 112802944 in <function mse_loss at 0x7f74c3bdbf70> shape torch.Size([10, 32, 128]) torch.float32
pck 105074496 in <function mse_loss at 0x7f74c3bdbf70> shape torch.Size([10, 32, 128]) torch.float32
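For reference, the same experiment should also work through the public saved-tensors-hooks API instead of the private push call. A minimal sketch, assuming a PyTorch version where torch.autograd.graph.saved_tensors_hooks is available:

# Sketch: register the same pack/unpack hooks via the public context manager
# instead of torch._C._autograd._push_saved_tensors_default_hooks.
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    with SetCurrentFunc():
        yhat = model(x)
        loss = nn.functional.mse_loss(yhat, y)
        loss.backward()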
Also, thanks for answering follow-up questions. That is reassuring.
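P.S. In case it helps anyone else reading this: a hedged sketch of accumulating the saved-tensor sizes per function instead of printing each one. saved_bytes and pack_hook_counting are names I made up; this reuses the global current_func set by SetCurrentFunc above.

from collections import defaultdict

# Hypothetical accumulator: function -> total bytes saved for backward.
saved_bytes = defaultdict(int)

def pack_hook_counting(x):
    # Attribute this saved tensor's size to the torch function that saved it.
    saved_bytes[current_func] += x.numel() * x.element_size()
    return x

# After the forward pass, saved_bytes maps each function (e.g. layer_norm, linear)
# to the total number of bytes it stashed for the backward pass.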