Capture from which module a function was dispatched

Hello,

Imagine you have a basic BERT model:

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(20, 32, padding_idx=0)
      (position_embeddings): Embedding(128, 32)
      (token_type_embeddings): Embedding(2, 32)
      (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=32, out_features=32, bias=True)
              (key): Linear(in_features=32, out_features=32, bias=True)
              (value): Linear(in_features=32, out_features=32, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=32, out_features=32, bias=True)
              (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=32, out_features=32, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=32, out_features=32, bias=True)
            (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=32, out_features=32, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0, inplace=False)
  (classifier): Linear(in_features=32, out_features=2, bias=True)
)

Now, during the backward pass, all kinds of functions are called/dispatched. I wrote my own tensor subclass to capture all the functions dispatched during my backward pass (the number is the number of calls):

{
 aten.select_backward.default : 3
 aten.t.default : 24
 aten.sum.dim_IntList : 8
 aten.view.default : 30
 aten.detach.default : 50
 aten.tanh_backward.default : 1
 aten.slice_backward.default : 1
 aten.native_layer_norm_backward.default : 3
 aten.gelu_backward.default : 1
 aten.maximum : 4
 aten.permute.default : 4
 aten.bmm.default : 4
 aten._softmax_backward_data.default : 1
 aten.div.Tensor : 1
 aten.transpose.int : 1
 aten.embedding_dense_backward.default : 3
}
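For anyone who wants to reproduce a count like the one above: here is a minimal sketch of an op counter. It uses a `TorchDispatchMode` rather than a tensor subclass (both go through `__torch_dispatch__`, so they see the same calls), and the model/loss are just placeholders, not my actual BERT setup:

```python
# Sketch: count every aten op dispatched during backward.
from collections import Counter

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpCounter(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.counts = Counter()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # func is an OpOverload, e.g. "aten.tanh_backward.default"
        self.counts[str(func)] += 1
        return func(*args, **(kwargs or {}))

x = torch.randn(8, 8, requires_grad=True)
loss = (x @ x).tanh().sum()

counter = OpCounter()
with counter:
    loss.backward()  # only ops dispatched inside the mode are recorded

for op, n in sorted(counter.counts.items()):
    print(op, ":", n)
```

Because only `loss.backward()` runs inside the mode, the counter records the backward ops alone; wrapping the forward in the same `with` block would capture both passes.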

Now I want to figure out in which nn.Module (i.e., which “Bert” submodule) a specific function call was made.

My first thought was to add a hook to each module and assume my data passes through the model sequentially. That way I could always track the “current module” and attribute each dispatched function to it. But that assumption might not always hold.
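To make the hook idea concrete for the backward pass, here is a sketch of what I had in mind: full backward pre-/post-hooks push and pop a “current module” stack, and a dispatch mode charges each op to whatever module is on top. Everything here (the `AttributingMode` name, the `<unattributed>` bucket, the toy `Sequential` model) is illustrative, and it still relies on the hooks nesting cleanly:

```python
# Sketch: attribute dispatched backward ops to the module running backward.
from collections import Counter, defaultdict

import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode

module_stack = []  # names of modules whose backward is currently running

class AttributingMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.per_module = defaultdict(Counter)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        owner = module_stack[-1] if module_stack else "<unattributed>"
        self.per_module[owner][str(func)] += 1
        return func(*args, **(kwargs or {}))

def track(name, module):
    def push(mod, grad_output):  # fires when backward reaches the module's output
        module_stack.append(name)
    def pop(mod, grad_input, grad_output):  # fires once grad_input is computed
        module_stack.pop()
    module.register_full_backward_pre_hook(push)
    module.register_full_backward_hook(pop)

model = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 2))
for name, mod in model.named_modules():
    if name:  # skip the root container itself
        track(name, mod)

mode = AttributingMode()
loss = model(torch.randn(3, 4, requires_grad=True)).sum()
with mode:
    loss.backward()

for owner, counts in mode.per_module.items():
    print(owner, dict(counts))
```

Ops that run outside any module's backward (e.g. the backward of the final `.sum()`, or gradient accumulation into `.grad`) end up in the `<unattributed>` bucket, which is exactly the kind of imprecision I'm worried about.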

I then saw that there are tools like torch.fx and the autograd profiler, but I’m not sure whether they are suitable for my needs.
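For the profiler route, this is roughly what I tried as a sanity check (the argument choices are mine; whether `with_modules` adds module information in eager mode seems to depend on the PyTorch version). The backward entries name the grad_fn, but not obviously the owning module:

```python
# Sketch: inspect what the profiler records for forward + backward.
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)

with profile(activities=[ProfilerActivity.CPU], with_modules=True) as prof:
    model(x).sum().backward()

table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=15)
print(table)
```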

I’m only interested in the backward pass at the moment, but feel free to also suggest any solution that works for both the forward and the backward pass.

Thanks in advance