From the documentation:
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])
This can be extended to
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300, 10)
>>> b = torch.ones(22, 300, 10)
>>> c = torch.ones(15, 300, 10)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300, 10])
to
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300, 10, 5)
>>> b = torch.ones(22, 300, 10, 5)
>>> c = torch.ones(15, 300, 10, 5)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300, 10, 5])
to
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300, 10, 5, 7)
>>> b = torch.ones(22, 300, 10, 5, 7)
>>> c = torch.ones(15, 300, 10, 5, 7)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300, 10, 5, 7])
I think you get the rough idea: * stands for (300,), (300, 10), (300, 10, 5) or (300, 10, 5, 7), i.e. any number of trailing dimensions.
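A minimal sketch that checks the same idea in one loop, reusing the lengths (25, 22, 15) and trailing shapes from the examples above. It assumes the default padding_value of 0 and also tries batch_first=True, which moves the batch dimension to the front. In every case the trailing dimensions pass through unchanged.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

lengths = [25, 22, 15]
# Trailing shapes from the examples above; "*" can be any of them.
trailing_shapes = [(300,), (300, 10), (300, 10, 5), (300, 10, 5, 7)]
shapes = []

for trailing in trailing_shapes:
    seqs = [torch.ones(n, *trailing) for n in lengths]

    # Default layout: (max_len, batch, *trailing), padded with 0.0.
    padded = pad_sequence(seqs)
    assert padded.shape == (max(lengths), len(lengths), *trailing)

    # Each sequence keeps its ones; positions past its length hold the padding value.
    for i, n in enumerate(lengths):
        assert torch.all(padded[:n, i] == 1)
        assert torch.all(padded[n:, i] == 0)

    # With batch_first=True the layout becomes (batch, max_len, *trailing).
    padded_bf = pad_sequence(seqs, batch_first=True)
    assert padded_bf.shape == (len(lengths), max(lengths), *trailing)

    shapes.append(tuple(padded.shape))

for s in shapes:
    print(s)
```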
Best regards
Thomas