What is the meaning of trailing dimensions?

From the documentation:

>>> import torch
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])
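To see what happens to each dimension, here is a smaller sketch (the shapes are arbitrary, chosen only to keep the output readable). Only the first, sequence-length dimension of the shorter tensor is padded; the trailing dimension is copied through as-is, and the padded steps are filled with padding_value, which defaults to 0.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two sequences of length 3 and 1; each step carries 2 features (the trailing dim).
a = torch.ones(3, 2)
b = torch.ones(1, 2)

padded = pad_sequence([a, b])  # default layout: (max_len, batch, *trailing)
print(padded.size())           # torch.Size([3, 2, 2])

# Sequence b (batch index 1) was extended from 1 to 3 steps along dim 0;
# the two extra steps hold the default padding_value of 0.
print(padded[:, 1])
```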

This can be extended to

>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300, 10)
>>> b = torch.ones(22, 300, 10)
>>> c = torch.ones(15, 300, 10)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300, 10])

to

>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300, 10, 5)
>>> b = torch.ones(22, 300, 10, 5)
>>> c = torch.ones(15, 300, 10, 5)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300, 10, 5])

to

>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300, 10, 5, 7)
>>> b = torch.ones(22, 300, 10, 5, 7)
>>> c = torch.ones(15, 300, 10, 5, 7)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300, 10, 5, 7])

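I think you get the rough idea: the * stands for any number of trailing dimensions, i.e. whatever dimensions come after the first (sequence-length) one. In the examples above it replaces (300,), (300, 10), (300, 10, 5) or (300, 10, 5, 7). pad_sequence only pads the first dimension and passes the trailing dimensions through unchanged, so they are assumed to be the same for all sequences in the list.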
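A compact way to check the pattern for several trailing shapes at once is a loop like the following sketch. The lengths are the ones from the examples above; the empty tuple () is added to show the case with no trailing dimensions at all, and batch_first=True is included to show that only the sequence and batch dimensions change places while the trailing dimensions stay at the end.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

lengths = [25, 22, 15]

# Each entry is one choice of trailing dims; () means no trailing dims.
for trailing in [(), (300,), (300, 10), (300, 10, 5)]:
    seqs = [torch.ones(n, *trailing) for n in lengths]

    # Default layout: (max_len, batch, *trailing).
    out = pad_sequence(seqs)
    assert out.shape == (max(lengths), len(seqs), *trailing)

    # batch_first=True gives (batch, max_len, *trailing).
    out_bf = pad_sequence(seqs, batch_first=True)
    assert out_bf.shape == (len(seqs), max(lengths), *trailing)

    print(trailing, tuple(out.shape))
```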

Best regards

Thomas
