Capturing the computation graph

Hello there,

Is there a way to capture the computation graph of an nn.Module other than torch.fx? The issue I am having with torch.fx is that symbolic tracing does not seem to work when the forward function contains conditionals. I want to obtain the same kind of graph that torch.fx produces, but I also need it to work for forward functions that have conditionals in them. For instance, I can capture the computation graph of nn.TransformerEncoderLayer but not nn.Transformer using torch.fx.
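For concreteness, here is a minimal sketch of what I mean. The constructor arguments are arbitrary placeholders rather than the configuration I actually use, but the behaviour is the same:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

# A module whose forward() has no data-dependent branching traces fine.
encoder_layer = nn.TransformerEncoderLayer(d_model=32, nhead=4)  # placeholder sizes
traced = symbolic_trace(encoder_layer)
print(traced.graph)

# nn.Transformer does not trace: its forward() branches on tensor shapes
# (e.g. comparing the batch dimensions of src and tgt), and symbolic
# tracing cannot evaluate such conditions on Proxy values.
model = nn.Transformer(d_model=32, nhead=4)  # placeholder sizes
try:
    symbolic_trace(model)
except Exception as err:
    print(f"symbolic_trace failed: {err}")
```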
