Hello there,
Is there a way to capture the computation graph of an nn.Module other than with torch.fx? My problem with torch.fx is that static capture (symbolic tracing) does not seem to work when the forward function contains conditionals. I want the same kind of computation graph that torch.fx produces, but also for forward functions that contain conditionals. For instance, using torch.fx I can capture the computation graph of nn.TransformerEncoderLayer but not that of nn.Transformer.
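Here is a minimal reproduction of the conditional problem, assuming a recent PyTorch release. The toy modules `Gated` and `Ungated` and the helper `try_symbolic_trace` are hypothetical names for illustration only. As I understand it, torch.fx.symbolic_trace runs forward on Proxy objects instead of real tensors, so an `if` on a tensor value cannot be resolved to a concrete bool and tracing raises a TraceError. The same forward without the branch traces fine.

```python
import torch
import torch.nn as nn
import torch.fx


class Gated(nn.Module):
    """Toy module (hypothetical) whose forward branches on the input's values."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # Data-dependent conditional: during symbolic tracing `x` is a Proxy,
        # so `x.sum() > 0` is also a Proxy and `if` cannot turn it into a bool.
        if x.sum() > 0:
            return self.linear(x)
        return -x


class Ungated(nn.Module):
    """Same layer without the conditional (hypothetical), for comparison."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


def try_symbolic_trace(module):
    """Return a GraphModule if torch.fx can trace `module`, else None."""
    try:
        return torch.fx.symbolic_trace(module)
    except torch.fx.proxy.TraceError as err:
        print(f"symbolic_trace failed: {err}")
        return None


if __name__ == "__main__":
    print(try_symbolic_trace(Ungated()).graph)  # prints the captured graph
    try_symbolic_trace(Gated())                 # prints the TraceError message
```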