Do the items in a list all have to be of a single type?

I am trying to trace NVIDIA's Tacotron 2 model and interface with it via the C++ frontend.

Running the traced function via the Python frontend works just fine and produces the expected results.

Through the C++ frontend, however, it complains that the list of weights fed into Tacotron's decoder LSTM does not have elements of equal size (each weight parameter has a different shape).

The code used to trace and export the model is as follows (NVIDIA's implementation is linked here: GitHub - NVIDIA/tacotron2: Tacotron 2 - PyTorch implementation with faster-than-realtime inference):

import numpy as np
import torch

from hparams import create_hparams
from text import text_to_sequence
from train import load_model

hparams = create_hparams()
hparams.sampling_rate = 22050

tacotron = load_model(hparams)
tacotron.load_state_dict(torch.load("tacotron2_statedict.pt", map_location='cpu')['state_dict'])
tacotron.eval()

print(tacotron)

text = "This is some random text."
sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()

traced_tacotron = torch.jit.trace(tacotron.inference, sequence, optimize=False, check_trace=False)
traced_tacotron.save("tacotronzzz.pt")

print(tacotron.inference(sequence))
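
As an aside, here is a small sketch of my own (assuming the trace above succeeded) for inspecting the graph before saving; this is presumably how the IR shown later in the thread was obtained:

# Inspect the traced graph; the prim::ListConstruct nodes that pack the LSTM
# flat weights into a single Tensor[] show up here.
print(traced_tacotron.graph)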

Here is the C++ frontend code:

#include <iostream>
#include <torch/script.h>
#include <torch/torch.h>

using namespace std;

int main() {
    shared_ptr<torch::jit::script::Module> tacotron = torch::jit::load("tacotronzzz.pt");

    assert(tacotron != nullptr);

    return 0;
}

If it helps, I can also provide a download link for tacotronzzz.pt.

Any help is much appreciated; if this is actually a bug, I’m happy to dig into the internals of libtorch and see if this could be fixed in any way.
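
For comparison, here is a minimal sketch of my own of the Python-side round trip that works (it assumes tacotronzzz.pt is in the working directory):

import torch

# Load the archive back through the Python frontend; this succeeds and the
# graph (including the Tensor[] of LSTM flat weights) is intact, which is why
# the failure looks specific to the C++ loader.
reloaded = torch.jit.load("tacotronzzz.pt")
print(reloaded.graph)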

Just to add some reference links regarding how Lists in the IR are interpreted:


In general, this seems to be a problem with tracing nn.LSTM.

Here is the class being traced:

import torch
from torch import nn
from torch.nn import functional as F

from layers import ConvNorm  # from the Tacotron 2 repository


class Encoder(nn.Module):
    """Encoder module:
        - Three 1-d convolution banks
        - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        # pytorch tensors are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs

This yields IR that builds lists whose elements are all Float tensors, but of different shapes/sizes:

  %hx.1 : Float(2, 1, 256) = aten::zeros(%105, %106, %107, %108), scope: LSTM
  %146 : Tensor[] = prim::ListConstruct(%hx.1, %hx.1), scope: LSTM
  %147 : Float(1024, 512) = prim::Constant[value=<Tensor>](), scope: LSTM
  %148 : Float(1024, 256) = prim::Constant[value=<Tensor>](), scope: LSTM
  %149 : Float(1024) = prim::Constant[value=<Tensor>](), scope: LSTM
  %150 : Float(1024) = prim::Constant[value=<Tensor>](), scope: LSTM
  %151 : Float(1024, 512) = prim::Constant[value=<Tensor>](), scope: LSTM
  %152 : Float(1024, 256) = prim::Constant[value=<Tensor>](), scope: LSTM
  %153 : Float(1024) = prim::Constant[value=<Tensor>](), scope: LSTM
  %154 : Float(1024) = prim::Constant[value=<Tensor>](), scope: LSTM
  %155 : Tensor[] = prim::ListConstruct(%147, %148, %149, %150, %151, %152, %153, %154), scope: LSTM
  %156 : bool = prim::Constant[value=1](), scope: LSTM
  %157 : int = prim::Constant[value=1](), scope: LSTM
  %158 : float = prim::Constant[value=0](), scope: LSTM
  %159 : bool = prim::Constant[value=0](), scope: LSTM
  %160 : bool = prim::Constant[value=1](), scope: LSTM
  %161 : bool = prim::Constant[value=1](), scope: LSTM
  %memory : Float(1!, 41, 512), %163 : Float(2, 1, 256), %164 : Float(2, 1, 256) = aten::lstm(%input.14, %146, %155, %156, %157, %158, %159, %160, %161), scope: LSTM
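
For what it's worth, here is a minimal repro of my own (not from the Tacotron 2 repo) showing that tracing a bare bidirectional nn.LSTM is already enough to produce a single Tensor[] holding flat weights of different shapes, matching the constants in the IR above:

import torch
import torch.nn as nn

class TinyLSTM(nn.Module):
    # Stand-in for the encoder's bidirectional LSTM: input 512, hidden 256.
    def __init__(self):
        super(TinyLSTM, self).__init__()
        self.lstm = nn.LSTM(512, 256, 1, batch_first=True, bidirectional=True)

    def forward(self, x):
        outputs, _ = self.lstm(x)
        return outputs

model = TinyLSTM().eval()
x = torch.randn(1, 41, 512)
traced = torch.jit.trace(model, x)
# The graph contains one prim::ListConstruct packing all flat weights into a
# Tensor[]: (1024, 512), (1024, 256) and (1024,) tensors in the same list.
print(traced.graph)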

Whoops, missed this somehow; sorry for the late reply. We have made a change on master that allows lists to hold tensors of different shapes/sizes. If you try a nightly build, it should work for you.
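
If it helps anyone verifying this, here is a quick sketch of my own for checking the behaviour on a nightly: a scripted function that returns a Tensor[] whose elements have different shapes should compile and run without complaint.

import torch

@torch.jit.script
def mixed_shapes(a, b):
    # Arguments default to Tensor in TorchScript; the returned Tensor[] holds
    # tensors of different shapes, like the LSTM flat-weights list above.
    return [a, b]

print([t.shape for t in mixed_shapes(torch.zeros(1024, 512), torch.zeros(1024))])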