The default end_dim is -1, which I read as meaning the last dimension is NOT flattened. However, if you pass a tensor of shape torch.Size([3, 2, 4, 3]) to nn.Flatten(), you get torch.Size([3, 24]), where 24 comes from flattening dimensions 1, 2 AND 3, i.e. a tensor of shape [dim 0, dim 1 x dim 2 x dim 3]. I was expecting a tensor of shape [dim 0, dim 1 x dim 2, dim 3]. Am I reading the docs wrong? They seem pretty clear, though, and wrong.
As you can see in the documentation:
Parameters
- start_dim – first dim to flatten (default = 1).
- end_dim – last dim to flatten (default = -1).
(Flatten — PyTorch 2.1 documentation)
This means that by default the tensor is flattened from the second dimension (start_dim=1) through the last dimension (end_dim=-1), and the first dimension (dim 0) stays the same.
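A minimal sketch of that default behaviour, using the shape from the question (the variable names are only for illustration):

```python
import torch
from torch import nn

# Same shape as in the question: [dim 0, dim 1, dim 2, dim 3]
x = torch.zeros(3, 2, 4, 3)

# Defaults: start_dim=1, end_dim=-1, so dims 1 through 3 are merged
flatten = nn.Flatten()
out_default = flatten(x)

print(out_default.shape)  # torch.Size([3, 24]), since 2 * 4 * 3 = 24
```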
But end_dim = -1 is not the last dim, it's the one before that, no? To be clear, nn.Flatten is doing what I want; I'm just saying the documentation is wrong, or at least confusing.
I think I misread the equations: for [d0, d1, d2, d3, …, dn], dim -1 would be dn, as in the doc.
end_dim = -1 refers to the last dimension
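As a quick check, here is a sketch showing that negative dims count from the end, as in Python indexing, and that end_dim=-2 is what would leave the last dimension untouched (the shape [dim 0, dim 1 x dim 2, dim 3] expected in the original question). The variable names are illustrative only.

```python
import torch
from torch import nn

x = torch.zeros(3, 2, 4, 3)

# For a 4-D tensor, end_dim=-1 and end_dim=3 name the same (last) dimension
out_neg = nn.Flatten(start_dim=1, end_dim=-1)(x)
out_pos = nn.Flatten(start_dim=1, end_dim=3)(x)
print(out_neg.shape == out_pos.shape)  # True

# end_dim=-2 stops one dimension earlier, so the last dim is kept as-is
out_keep_last = nn.Flatten(start_dim=1, end_dim=-2)(x)
print(out_keep_last.shape)  # torch.Size([3, 8, 3])
```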