When creating a variable-length PackedSequence with `batch_first=True`, accessing the `.data` attribute returns the sequences out of order, as if `batch_first=False`: the elements come back interleaved by time step rather than one sequence after another.
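For reference, here is a minimal sketch of the layout `pack_padded_sequence` appears to produce. `.data` is built time step by time step, taking step `t` from every sequence that is still long enough, and `batch_sizes[t]` records how many sequences contributed at that step. The helper `manual_pack_data` is hypothetical, written only to illustrate this, and it assumes the sequences are already sorted by decreasing length.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


def manual_pack_data(seqs):
    """Hypothetical helper: rebuild PackedSequence.data and batch_sizes by hand.

    Assumes `seqs` is sorted by decreasing length and each element has
    shape (length, features).
    """
    rows, batch_sizes = [], []
    for t in range(seqs[0].size(0)):
        # Take time step t from every sequence that is long enough to have one.
        step = [s[t] for s in seqs if s.size(0) > t]
        rows.extend(step)
        batch_sizes.append(len(step))
    return torch.stack(rows), torch.tensor(batch_sizes)


seqs = [torch.rand(3, 2), torch.rand(2, 2), torch.rand(1, 2)]
lengths = torch.tensor([3, 2, 1])
packed = pack_padded_sequence(pad_sequence(seqs, batch_first=True), lengths, batch_first=True)

data, batch_sizes = manual_pack_data(seqs)
print(torch.equal(data, packed.data))                # True
print(torch.equal(batch_sizes, packed.batch_sizes))  # True

# Each time step occupies one contiguous block of rows, which is presumably
# what makes this layout convenient for a step-by-step RNN loop.
offset = 0
for t, bs in enumerate(packed.batch_sizes.tolist()):
    step_rows = packed.data[offset:offset + bs]
    assert torch.equal(step_rows, torch.stack([s[t] for s in seqs[:bs]]))
    offset += bs
```

If this reading is right, the ordering is independent of `batch_first`, which only tells `pack_padded_sequence` how to interpret its padded input.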
I don’t really understand the reasoning behind making the sequence the first dimension by default (it seems less intuitive, given how PyTorch otherwise treats batches), but I’m assuming it is for performance reasons. Even so, given that the `.data` attribute is public-facing, I feel it should be returned in the same order it was given. Then those of us writing modules that use padded and packed sequences could handle this input more naturally, without hacking together even more reordering code. Requiring `pack_padded_sequence()` and `pack_sequence()` to receive sequences in decreasing order of length is enough of a hassle already, but that’s another topic.
I wasn’t sure whether this behavior was intended, so I’m posting here rather than filing a bug report. Is this behavior correct? If so, why, and how does PyTorch recommend dealing with it?
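One workaround that stays within the public API is to unpack with `pad_packed_sequence(..., batch_first=True)` and slice off the padding using the returned lengths, which gives back the concatenation in the original batch order. This is a sketch of one possible approach, not necessarily the officially recommended one:

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

seqs = [torch.rand(3, 2), torch.rand(2, 2), torch.rand(1, 2)]
lengths = torch.tensor([3, 2, 1])
packed = pack_padded_sequence(pad_sequence(seqs, batch_first=True), lengths, batch_first=True)

# Unpack to a (batch, max_len, features) padded tensor plus per-sequence lengths.
padded, out_lengths = pad_packed_sequence(packed, batch_first=True)

# Drop each sequence's padding and concatenate in batch order.
batch_order_data = torch.cat([padded[i, :n] for i, n in enumerate(out_lengths.tolist())])
print(torch.equal(batch_order_data, torch.cat(seqs)))  # True
```

The cost is materializing the padded tensor once; in exchange, the code does not depend on how `.data` is laid out internally.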
Code to Reproduce Behavior:
```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# Three sequences of shape (length, features), already sorted by decreasing length.
batch_first_seqs = [
    torch.rand((3, 2)),
    torch.rand((2, 2)),
    torch.rand((1, 2)),
]
lengths = torch.LongTensor([3, 2, 1])

# Padded tensor of shape (batch, max_len, features).
padded_seqs = pad_sequence(batch_first_seqs, batch_first=True)
packed_seqs = pack_padded_sequence(padded_seqs, lengths=lengths, batch_first=True)

print(batch_first_seqs)
print(padded_seqs)
print(packed_seqs)
# Compare .data against the sequences concatenated in batch order.
print(torch.cat(batch_first_seqs) == packed_seqs.data)
```
Output:
```
[tensor([[0.7967, 0.5329],
        [0.6376, 0.3543],
        [0.6514, 0.8007]]), tensor([[0.1709, 0.1577],
        [0.5007, 0.8083]]), tensor([[0.3345, 0.7590]])]
tensor([[[0.7967, 0.5329],
         [0.6376, 0.3543],
         [0.6514, 0.8007]],

        [[0.1709, 0.1577],
         [0.5007, 0.8083],
         [0.0000, 0.0000]],

        [[0.3345, 0.7590],
         [0.0000, 0.0000],
         [0.0000, 0.0000]]])
PackedSequence(data=tensor([[0.7967, 0.5329],
        [0.1709, 0.1577],
        [0.3345, 0.7590],
        [0.6376, 0.3543],
        [0.5007, 0.8083],
        [0.6514, 0.8007]]), batch_sizes=tensor([3, 2, 1]))
tensor([[1, 1],
        [0, 0],
        [0, 0],
        [0, 0],
        [1, 1],
        [0, 0]], dtype=torch.uint8)
```
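The pattern in the last tensor follows from the time-major layout: only row 0 (step 0 of sequence 0) and row 4 (step 1 of sequence 1) happen to sit at the same position in both orderings. If padding is undesirable, the batch-order view can also be recovered with a single index permutation derived from `batch_sizes`. The helper `packed_to_batch_order_index` below is hypothetical, and it assumes the batch was packed already sorted (so `sorted_indices` is `None`).

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


def packed_to_batch_order_index(batch_sizes):
    """Hypothetical helper: indices that reorder PackedSequence.data into
    batch-first concatenation order (all steps of seq 0, then seq 1, ...).

    Assumes the sequences were packed in sorted order (sorted_indices is None).
    """
    batch_sizes = batch_sizes.tolist()
    # offsets[t] = first row of .data that belongs to time step t
    offsets = [0]
    for bs in batch_sizes[:-1]:
        offsets.append(offsets[-1] + bs)
    index = []
    for b in range(batch_sizes[0]):
        # Sequence b has a step t whenever more than b sequences are active at t.
        for t, bs in enumerate(batch_sizes):
            if b < bs:
                index.append(offsets[t] + b)
    return torch.tensor(index)


seqs = [torch.rand(3, 2), torch.rand(2, 2), torch.rand(1, 2)]
lengths = torch.tensor([3, 2, 1])
packed = pack_padded_sequence(pad_sequence(seqs, batch_first=True), lengths, batch_first=True)

index = packed_to_batch_order_index(packed.batch_sizes)
print(index)                                             # tensor([0, 3, 5, 1, 4, 2])
print(torch.equal(packed.data[index], torch.cat(seqs)))  # True
```

Because this is plain advanced indexing, gradients flow through `packed.data[index]` as usual; the index depends only on `batch_sizes`, so it could be computed once per batch and reused.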