nn.Flatten() documentation seems wrong

The default end_dim is -1, which means the last dimension is NOT flattened. However, if you pass a tensor of torch.Size([3, 2, 4, 3]) to nn.Flatten(), you get a torch.Size([3, 24]), where 24 comes from flattening dimensions 1, 2 AND 3, returning a tensor of size [dim 0, dim 1 x dim 2 x dim 3]. I was expecting a tensor of shape [dim 0, dim 1 x dim 2, dim 3]. Am I reading the doc wrong? It seems pretty clear, though, and wrong.
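A minimal reproduction of the shapes described above, assuming a standard PyTorch install (the random input is only for illustration). It also shows which arguments give the [dim 0, dim 1 x dim 2, dim 3] shape that was expected:

```python
import torch
import torch.nn as nn

x = torch.randn(3, 2, 4, 3)

# Default nn.Flatten(): start_dim=1, end_dim=-1, so dims 1, 2 and 3 are merged.
default_out = nn.Flatten()(x)
print(default_out.shape)  # torch.Size([3, 24])

# To merge only dims 1 and 2 and keep the last one, stop one dimension earlier.
# For a 4-D input, end_dim=2 and end_dim=-2 name the same dimension.
partial_out = nn.Flatten(start_dim=1, end_dim=2)(x)
print(partial_out.shape)  # torch.Size([3, 8, 3])
```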

As you can see in the documentation:


This means the default is to flatten the tensor from the second dimension (start_dim=1) through the last dimension (end_dim=-1), while the first dimension (dim 0) stays the same.
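The default described here can be checked with a short sketch. It assumes a recent PyTorch version and compares the module with the functional torch.flatten and a plain reshape; these equivalences are an illustration, not something the quoted doc states. Note that torch.flatten has its own default of start_dim=0, so start_dim=1 is passed explicitly:

```python
import torch
import torch.nn as nn

x = torch.arange(3 * 2 * 4 * 3).reshape(3, 2, 4, 3)

a = nn.Flatten()(x)                            # module defaults: start_dim=1, end_dim=-1
b = torch.flatten(x, start_dim=1, end_dim=-1)  # functional form with the same range
c = x.reshape(x.size(0), -1)                   # keep dim 0, merge everything else

print(a.shape)                                   # torch.Size([3, 24])
print(torch.equal(a, b) and torch.equal(a, c))   # True
```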

But end_dim = -1 is not the last dim, it's the one before that, no? To be clear, nn.Flatten is doing what I want; I'm just saying the documentation is wrong, or at least confusing.

I think I misread the equations… For [d0, d1, d2, d3, …, dn], dim -1 would be dn, as in the doc.

end_dim = -1 refers to the last dimension
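To make the negative indexing concrete, here is a small pure-Python sketch of the shape rule. `flattened_shape` is a hypothetical helper written for illustration, not part of PyTorch; it normalizes negative dims the usual way (-1 becomes ndim - 1) and multiplies the sizes in the inclusive range:

```python
from math import prod


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: the shape nn.Flatten(start_dim, end_dim) should produce."""
    ndim = len(shape)
    # Negative dims count from the end: -1 -> ndim - 1 (the last dim), -2 -> ndim - 2.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid range start_dim={start_dim}, end_dim={end_dim}")
    # Dims before start and after end are kept; dims start..end (inclusive) are merged.
    return tuple(shape[:start]) + (prod(shape[start:end + 1]),) + tuple(shape[end + 1:])


print(flattened_shape((3, 2, 4, 3)))              # (3, 24)
print(flattened_shape((3, 2, 4, 3), end_dim=-2))  # (3, 8, 3)
```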