Hello,
I’m trying to use `torch.fx.symbolic_trace` to trace the `nn.LSTM` module.
However, I’m running into the following error:

`torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow`
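For context, here is a minimal sketch of what seems to trigger this kind of error. It assumes a hypothetical toy module `Branchy` whose `if` on the input's rank is similar in spirit to the shape checks inside `nn.LSTM.forward`. During tracing, `x` is an FX `Proxy`, so the condition is also a `Proxy`, and the `if` asks Python for a concrete `bool`, which FX cannot provide.

```python
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


class Branchy(nn.Module):
    # Hypothetical toy module, only for illustrating the error.
    def forward(self, x):
        # While tracing, x.dim() == 3 is recorded as a Proxy, not a real bool;
        # the `if` statement calls bool() on it, which raises TraceError.
        if x.dim() == 3:
            return x * 2
        return x


try:
    symbolic_trace(Branchy())
    raised = False
except TraceError as e:
    raised = True
    print(e)
```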
Full code:
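```python
from torch.fx import Tracer
import torch.nn as nn
import torch

lstm = nn.LSTM(300, 100, 1)  # input_size=300, hidden_size=100, num_layers=1
x = torch.randn(7, 64, 300)  # (seq_len, batch, input_size)
h = torch.randn(1, 64, 100)  # (num_layers, batch, hidden_size)
c = torch.randn(1, 64, 100)
lstm(x, (h, c))  # the eager forward pass runs fine
nodes = Tracer().trace(lstm).nodes  # this line raises the TraceError
```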
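One possible workaround, sketched under the assumption that a single opaque node for the whole LSTM is acceptable: the default `Tracer.is_leaf_module` treats modules from the `torch.nn` namespace as leaves when they are *submodules*, but the root module is always traced through. Wrapping the LSTM in a parent module (the `LSTMWrapper` name is hypothetical) should therefore record it as one `call_module` node instead of tracing its control flow.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class LSTMWrapper(nn.Module):
    # Hypothetical wrapper so nn.LSTM is a submodule rather than the trace root.
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(300, 100, 1)

    def forward(self, x, h, c):
        # nn.LSTM is a leaf by default, so this call becomes a single
        # call_module node and LSTM.forward itself is not traced.
        return self.lstm(x, (h, c))


gm = symbolic_trace(LSTMWrapper())
print([n.op for n in gm.graph.nodes])

x = torch.randn(7, 64, 300)
h = torch.randn(1, 64, 100)
c = torch.randn(1, 64, 100)
out, (h_n, c_n) = gm(x, h, c)
print(out.shape)  # torch.Size([7, 64, 100])
```

The trade-off is that the LSTM's internals do not appear in the graph; if those nodes are what you need, this approach will not give them to you.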