Error when using torch.fx.symbolic_trace to trace an LSTM

Hello,

I’m trying to use torch.fx.symbolic_trace to trace an nn.LSTM module.

However, I’m running into the following error:

torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

Full code:

from torch.fx import Tracer
import torch.nn as nn
import torch

# Minimal reproducer: tracing a stock nn.LSTM with torch.fx.
lstm = nn.LSTM(300, 100, 1)  # input_size=300, hidden_size=100, num_layers=1
# nn.LSTM defaults to batch_first=False, so x is (seq_len=7, batch=64, feature=300).
x = torch.randn(7, 64, 300)
h = torch.randn(1, 64, 100)  # initial hidden state: (num_layers, batch, hidden)
c = torch.randn(1, 64, 100)  # initial cell state: (num_layers, batch, hidden)
lstm(x, (h, c))  # plain eager call, outside of any tracing

# This is the line that raises torch.fx.proxy.TraceError ("symbolically traced
# variables cannot be used as inputs to control flow"). Presumably
# nn.LSTM.forward branches in Python on properties of its input, which is a
# symbolic Proxy during tracing.
nodes = Tracer().trace(lstm).nodes

Hello,
@MartinZhang, did you find a solution to this issue?

I am trying to do the same, but I can't update the LSTM's states when tracing it with torch.fx.

Thanks,

Hi,

I wrote a custom LSTM and traced that instead.

Please note the seq_sz attribute.

You need to run this module forward once on a real input first, so that it saves seq_sz, and then trace it.

If some submodule can’t be traced, you can override the tracer’s is_leaf_module so that the submodule is kept as a single opaque call in the graph.

class CustomLSTM(nn.Module):
    """Single-layer LSTM built from plain tensor ops so torch.fx can trace it.

    The time loop is an ordinary Python ``for`` loop over a *concrete*
    sequence length. Every eager (non-traced) ``forward`` call records that
    length in ``self.seq_sz``. When ``forward`` later runs under torch.fx
    tracing, the loop is unrolled ``self.seq_sz`` times, so the traced graph
    is specialized to that sequence length.

    Args:
        input_sz: Size of the last (feature) dimension of the input.
        hidden_sz: Size of the hidden and cell states.
    """

    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        # The input, forget, candidate and output gate weights are packed side
        # by side along the last dimension (4 * hidden_sz), in that order.
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()
        # Number of time steps to unroll when traced; set by eager forward().
        self.seq_sz = 0

    def init_weights(self):
        """Fill every parameter from U(-1/sqrt(hidden), 1/sqrt(hidden))."""
        import math  # local import: the snippet has no module-level `import math`

        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM over every time step of ``x``.

        Args:
            x: Input of shape (batch, sequence, feature).
            init_states: Optional ``(h_0, c_0)`` tuple, each of shape
                (batch, hidden). When omitted, zeros with ``x``'s dtype and
                device are used.

        Returns:
            ``(hidden_seq, (h_t, c_t))``. ``hidden_seq`` has shape
            (batch, sequence, hidden), and ``h_t`` / ``c_t`` are the final
            hidden and cell states.

        Side effects:
            An eager call sets ``self.seq_sz`` to the sequence length of ``x``.

        Raises:
            ValueError: in eager mode, if ``x`` is not 3-D.
            RuntimeError: if traced before any eager call has set ``seq_sz``.

        Note:
            torch.fx also turns defaulted arguments such as ``init_states``
            into graph inputs. Pass ``concrete_args={"init_states": None}`` to
            ``symbolic_trace`` to bake in the zero-initial-state path.
        """
        from torch.fx import Proxy

        if not isinstance(x, Proxy):
            # Eager call: record the real sequence length for later tracing.
            # Unpacking the size also rejects inputs that are not 3-D.
            _, self.seq_sz, _ = x.size()
        elif self.seq_sz == 0:
            raise RuntimeError(
                "CustomLSTM must be run eagerly at least once before tracing "
                "so that seq_sz (the number of time steps to unroll) is known"
            )

        hs = self.hidden_size
        # x.size(0) avoids tuple-unpacking a Proxy while tracing.
        bs = x.size(0)
        if init_states is None:
            # new_zeros matches x's dtype and device, also when traced.
            h_t = x.new_zeros(bs, hs)
            c_t = x.new_zeros(bs, hs)
        else:
            h_t, c_t = init_states

        hidden_seq = []
        for t in range(self.seq_sz):
            x_t = x[:, t, :]
            # Compute all four gates with a single pair of matmuls.
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_t = torch.sigmoid(gates[:, :hs])  # input gate
            f_t = torch.sigmoid(gates[:, hs:hs * 2])  # forget gate
            g_t = torch.tanh(gates[:, hs * 2:hs * 3])  # candidate cell state
            o_t = torch.sigmoid(gates[:, hs * 3:])  # output gate
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
        # Stacking (batch, hidden) steps along dim 1 gives
        # (batch, sequence, hidden) directly, already contiguous.
        hidden_seq = torch.stack(hidden_seq, dim=1)
        return hidden_seq, (h_t, c_t)

I hope this helps.