Error when using torch.fx.symbolic_trace to trace an LSTM


I’m trying to use torch.fx.symbolic_trace to trace an nn.LSTM module.

However, I’m running into the following error:

torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
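The error itself can be reproduced without an LSTM. During symbolic tracing, inputs are FX Proxy objects, so any Python `if` that depends on them (for example, a shape check on `x.dim()`) needs a concrete bool that FX cannot supply. The function below is a hypothetical minimal example, not code from nn.LSTM. nn.LSTM's forward appears to perform input-shape validation of a similar kind, but the exact check varies between PyTorch versions.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError

def add_one_if_3d(x):
    # Under tracing, x.dim() returns a Proxy, and so does the comparison.
    # The `if` then calls Proxy.__bool__, which raises TraceError.
    if x.dim() != 3:
        raise ValueError("expected a 3-D input")
    return x + 1

try:
    symbolic_trace(add_one_if_3d)
    error_message = None
except TraceError as e:
    error_message = str(e)

print(error_message)
```

Calling `add_one_if_3d` eagerly on a real tensor works, because `x.dim()` is then a plain int.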

Full code:

from torch.fx import Tracer
import torch.nn as nn
import torch

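# Single-layer LSTM: input_size=300, hidden_size=100, num_layers=1
lstm = nn.LSTM(300, 100, 1)
x = torch.randn(7, 64, 300)  # (seq_len, batch, input_size)
h = torch.randn(1, 64, 100)  # (num_layers, batch, hidden_size)
c = torch.randn(1, 64, 100)
lstm(x, (h, c))  # the eager forward pass runs fine

# Tracing the LSTM itself as the root module raises the TraceError above
nodes = Tracer().trace(lstm).nodes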
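One possible workaround, sketched under the assumption that your PyTorch version uses FX's default leaf-module rule: the default Tracer treats modules from the torch.nn namespace as leaves and records each call as a single call_module node, but the root module passed to trace() always has its own forward traced. Tracing nn.LSTM directly therefore traces through its forward and hits its Python checks. Wrapping the LSTM in a small parent module (`LSTMWrapper` is a hypothetical name) lets FX record it as a leaf instead.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class LSTMWrapper(nn.Module):
    """Hypothetical wrapper so the LSTM is a submodule, not the trace root."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(300, 100, 1)

    def forward(self, x, h, c):
        # During tracing, self.lstm is a leaf module, so FX records one
        # call_module node instead of tracing nn.LSTM.forward itself.
        return self.lstm(x, (h, c))

model = LSTMWrapper()
gm = symbolic_trace(model)

# Inspect the recorded graph: placeholders, one call_module, and the output.
for node in gm.graph.nodes:
    print(node.op, node.target)

x = torch.randn(7, 64, 300)
h = torch.randn(1, 64, 100)
c = torch.randn(1, 64, 100)

# The traced GraphModule reuses the same LSTM submodule, so its outputs
# should match the original wrapper's eager outputs.
out, (h_n, c_n) = gm(x, h, c)
ref_out, (ref_h, ref_c) = model(x, h, c)
```

The trade-off is that the LSTM's internals stay opaque to FX: the graph shows the LSTM as one node rather than its individual operations. If only the node list is needed, `Tracer().trace(model).nodes` on the wrapper should behave the same way.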