I am getting an error when I try to compute model FLOPs with torcheval's module_summary. Specifically, the issue seems to be that module_summary calls the forward function on every submodule of my model (see here), unpacking the provided args and passing them through. One of my submodules is an nn.LSTM that I pass a PackedSequence to, and since a PyTorch PackedSequence inherits from a named tuple, it gets unpacked (sorry, this term is overloaded; I mean unpacked as an iterable into positional args to the LSTM forward function), which raises an error.
example:
embedded = self.embeddings(source)
# Packs embedded source symbols into a PackedSequence.
packed = nn.utils.rnn.pack_padded_sequence(
    embedded, lengths, batch_first=True, enforce_sorted=False
)
# -> B x seq_len x encoder_dim, (h0, c0).
packed_outs, (H, C) = self.encoder(packed)
encoded, _ = nn.utils.rnn.pad_packed_sequence(
    packed_outs,
    batch_first=True,
    padding_value=self.pad_idx,
    total_length=None,
)
While profiling FLOPs, this causes: TypeError: forward() takes from 2 to 3 positional arguments but 5 were given
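To make the mechanism concrete without torch, here is a minimal stand-in showing how a named tuple gets spread into positional args when star-unpacked; FakePackedSequence and lstm_forward are hypothetical mimics of the real PackedSequence and nn.LSTM.forward, not actual torch APIs:

```python
from typing import NamedTuple


class FakePackedSequence(NamedTuple):
    # Hypothetical stand-in: the real PackedSequence also behaves as a
    # named tuple with four fields.
    data: object
    batch_sizes: object
    sorted_indices: object
    unsorted_indices: object


def lstm_forward(x, hx=None):
    # Mimics nn.LSTM.forward's arity: one required input, one optional
    # hidden state.
    return x


packed = FakePackedSequence("data", "sizes", None, None)

# A direct call works: the named tuple is a single positional argument.
lstm_forward(packed)

# Star-unpacking it as an iterable (what the profiler effectively does)
# spreads its four fields into four positional arguments and fails.
try:
    lstm_forward(*packed)
except TypeError as e:
    print(type(e).__name__)  # TypeError
```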
Of course, I can simply pass packed inside an iterable, which solves the issue when profiling:
packed_outs, (H, C) = self.encoder((packed,))
but then the actual forward pass errors when I am not profiling with torcheval, since the LSTM expects a PackedSequence: AttributeError: 'tuple' object has no attribute 'dim'
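The only other idea I have is to wrap the encoder so its forward tolerates both calling conventions, rebuilding the PackedSequence if the profiler has spread its fields. This is just a torch-free sketch; PackedStandIn and encoder_forward are hypothetical stand-ins for PackedSequence and a wrapper module's forward:

```python
from typing import NamedTuple


class PackedStandIn(NamedTuple):
    # Hypothetical stand-in for PackedSequence (same four fields).
    data: object
    batch_sizes: object
    sorted_indices: object
    unsorted_indices: object


def encoder_forward(*args):
    # Accepts either a single PackedStandIn (the normal call) or its
    # fields spread into positional args (what the profiler does).
    if len(args) == 1 and isinstance(args[0], PackedStandIn):
        packed = args[0]
    else:
        # Rebuild the sequence from its spread fields.
        packed = PackedStandIn(*args)
    return packed  # real code would call the wrapped nn.LSTM here


packed = PackedStandIn(1, 2, 3, 4)
encoder_forward(packed)    # normal call
encoder_forward(*packed)   # profiler-style spread call
```

but that feels hacky, and I'm not sure it's robust to how torcheval actually re-invokes forward.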
This seems like a really silly issue that someone else must have run into. I am probably being dense, but is there a simple solution to this?