Retrieving forward operations dynamically

Hello,

I am trying to dynamically analyze the forward pass, i.e. the operations performed on an input tensor X. Suppose we have the following code:

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # These can all be found using named_modules() or children()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, (1, 1), (2, 2), bias=True),
            nn.BatchNorm2d(16),
        )
        self.act = nn.SiLU()

    def forward(self, X):
        # we can retrieve these operations via .modules(), named_modules(), children(), etc.
        output = self.conv(X)
        output = self.act(output)
        # But not this
        final_output = torch.sigmoid(output)
        return final_output

The module evaluation is done below (for some reason, Markdown is parsing the remaining code incorrectly, so I added it separately):


if __name__ == '__main__':
    model = Model()

    for name, module in model.named_children():
        print(f'Name: {name}, module: {module}')

As noted in the comments above, since self.conv and self.act are nn.Module instances, we can retrieve these layers using model.named_children(). However, torch.sigmoid is a stateless function, and I can't think of a way to obtain that part of the forward pass without converting the model to ONNX.

Is there a way to do this in PyTorch?

Thank you!

I don’t know what exactly you are trying to do, but maybe torch.fx could help?
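For example, a minimal sketch reusing your Model: torch.fx.symbolic_trace records every operation executed in forward as a node in a graph, including stateless calls such as torch.sigmoid, which show up as call_function nodes:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, (1, 1), (2, 2), bias=True),
            nn.BatchNorm2d(16),
        )
        self.act = nn.SiLU()

    def forward(self, X):
        output = self.conv(X)
        output = self.act(output)
        return torch.sigmoid(output)


traced = symbolic_trace(Model())
for node in traced.graph.nodes:
    # node.op is one of: placeholder, call_module,
    # call_function, call_method, get_attr, output
    print(node.op, node.target)
```

Note that torch.sigmoid appears as a call_function node with the function object itself as the target, which named_children() could never surface.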

1 Like

@ptrblck Thank you, it looks very promising. I will check it out right now.

I asked this question because I am building a general conversion library that converts PyTorch models into Keras models (that output the same values).

I have been moving weights and operations into the Keras model by iterating over named_children() and writing appropriate converters, but I ran into a snag: stateless operations inside the forward function that do not extend nn.Module were not being added to the conversion.

The last resort would be to parse the ONNX computational graph node by node, but that can get pretty ugly, so I'd rather avoid it.
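For reference, a rough sketch of how an fx graph could drive the conversion instead of named_children(). The "plan" entries here are hypothetical placeholders standing in for real per-op Keras converters; the point is only that stateful modules and stateless functions are both visible in traversal order:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


def conversion_plan(model: nn.Module):
    """Trace the model and list every forward-pass operation in order.

    Each entry says whether the op is a stateful module (copy its weights)
    or a stateless function/method (map it to a Keras equivalent). The
    tuples are stand-ins for real converter dispatch.
    """
    traced = symbolic_trace(model)
    submodules = dict(traced.named_modules())
    plan = []
    for node in traced.graph.nodes:
        if node.op == 'call_module':
            # stateful layer: the concrete nn.Module holds the weights
            plan.append(('module', type(submodules[node.target]).__name__))
        elif node.op == 'call_function':
            # stateless op written directly in forward, e.g. torch.sigmoid
            plan.append(('function', node.target.__name__))
        elif node.op == 'call_method':
            # tensor methods such as x.view(...) or x.permute(...)
            plan.append(('method', node.target))
    return plan


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))


print(conversion_plan(Tiny()))
# e.g. [('module', 'Linear'), ('function', 'sigmoid')]
```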

@ptrblck I played around with it a little bit, and this was exactly what I was looking for! Thank you so much for taking the time to respond :slight_smile:

1 Like