Unable to use torcheval `module_summary` with packed sequences for LSTM

I am getting an error when I try to compute model FLOPs with torcheval's `module_summary`.

Specifically, the issue seems to be that `module_summary` calls the `forward` function on every submodule of my model (see here), unpacking the provided args and passing them through.

Since one of my submodules is an `nn.LSTM` that I pass a `PackedSequence` to, and a PyTorch `PackedSequence` inherits from a named tuple, it gets unpacked (sorry, this term is overloaded; I mean splatted as an iterable into positional args to the LSTM's `forward` function), and the call errors.
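For reference, here is a minimal repro of that splatting outside of torcheval (a sketch, assuming `module_summary` star-unpacks the args when calling the submodule's `forward`, which is what the error suggests):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
    packed = nn.utils.rnn.pack_padded_sequence(
        torch.randn(2, 5, 8),
        torch.tensor([5, 3]),
        batch_first=True,
        enforce_sorted=False,
    )
    # PackedSequence subclasses a namedtuple with four fields
    # (data, batch_sizes, sorted_indices, unsorted_indices), so iterating
    # over it yields those fields as separate positional args:
    lstm(*packed)
    # TypeError: forward() takes from 2 to 3 positional arguments but 5 were given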

Here is the relevant part of my model's `forward`:

        embedded = self.embeddings(source)
        # Packs embedded source symbols into a PackedSequence.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        # The encoder returns the packed outputs plus the final hidden and
        # cell states (h_n, c_n); once padded below, the outputs are
        # B x seq_len x encoder_dim.
        packed_outs, (H, C) = self.encoder(packed)
        encoded, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs,
            batch_first=True,
            padding_value=self.pad_idx,
            total_length=None,
        )

This causes `TypeError: forward() takes from 2 to 3 positional arguments but 5 were given` while profiling FLOPs.

Of course, I can simply wrap `packed` in a tuple, which solves the issue when profiling:

    packed_outs, (H, C) = self.encoder((packed,))

but then the actual forward pass errors when I am not profiling with torcheval, since the LSTM expects a `PackedSequence` (or a tensor): `AttributeError: 'tuple' object has no attribute 'dim'`.
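For illustration, the best workaround I can sketch is a small wrapper whose `forward` accepts both calling conventions. This is untested against torcheval's internals and assumes the splatting happens at the submodule's `forward` call; the name `PackedLSTM` is just illustrative, and note that the PyTorch docs discourage constructing a `PackedSequence` by hand (here the fields come from a real one, so it should be safe):

    import torch.nn as nn

    class PackedLSTM(nn.Module):
        """Sketch: accept a PackedSequence directly (the normal forward
        pass) or its four star-unpacked fields (what module_summary
        appears to pass while profiling)."""

        def __init__(self, *args, **kwargs):
            super().__init__()
            self.lstm = nn.LSTM(*args, **kwargs)

        def forward(
            self, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None
        ):
            if not isinstance(data, nn.utils.rnn.PackedSequence):
                # Profiling path: reassemble the PackedSequence from its fields.
                data = nn.utils.rnn.PackedSequence(
                    data, batch_sizes, sorted_indices, unsorted_indices
                )
            return self.lstm(data)

Swapping `self.encoder` for a wrapper like this would keep the normal forward pass working, but it feels heavy-handed for what looks like a profiler quirk.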

This seems like a really silly issue that someone else must have run into. I am probably being dense, but is there a simple solution to this?

I would recommend posting this issue in the torcheval GitHub repository so it can be tracked and fixed, since the library is still in its alpha stage.
