I am looking for an elegant way to flatten a tensor of arbitrary shape into a matrix, based on a single parameter that specifies which dimension to retain. For illustration, I would like something like:

def my_func(input, dim):
    # code to compute output
    return output

For example, given an input tensor of shape 2x3x4, the output should be a tensor of shape 12x2 for dim=0, 8x3 for dim=1, and 6x4 for dim=2. If I only want to keep the last dimension and flatten all the others, this is easily accomplished by

input.view(-1, input.shape[-1])
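A minimal sketch of that last-dimension case, assuming a small example tensor built with torch.arange purely for illustration:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)
# view() reinterprets the existing row-major memory: the last dimension
# stays intact and the two leading dimensions (2 * 3) are merged.
flat = x.view(-1, x.shape[-1])
print(flat.shape)  # torch.Size([6, 4])
```

Each row of flat is one x[i, j, :] slice, in row-major order of (i, j).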

But I would like to support an arbitrary dim, and to do it elegantly, without enumerating every case with if conditions. It might be possible to first permute the dimensions so that the dimension of interest is last, and then apply the operation above.
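One way to sketch that idea, assuming PyTorch 1.7 or later (which provides torch.movedim): move dim to the end while keeping the other dimensions in their original order, then flatten. my_func is the name from the question; the body is only an illustration, not the solution given in the reply below.

```python
import torch

def my_func(input: torch.Tensor, dim: int) -> torch.Tensor:
    # Move `dim` to the last position; the remaining dims keep their order.
    moved = torch.movedim(input, dim, -1)
    # reshape() returns a view when possible and copies only when it must.
    return moved.reshape(-1, input.shape[dim])

x = torch.randn(2, 3, 4)
for d in range(x.dim()):
    print(d, tuple(my_func(x, d).shape))
# 0 (12, 2)
# 1 (8, 3)
# 2 (6, 4)
```

Using reshape() instead of view() avoids an explicit contiguous() call, because movedim generally produces a non-contiguous view.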

I am not sure how you expect the flattening to happen, but this swaps the selected dimension with the last one and flattens all the leading dimensions using PyTorch's row-major layout.
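A sketch of the swap-then-flatten approach described in this reply, assuming transpose() is used for the swap; flatten_keep_dim is a hypothetical name chosen for illustration:

```python
import torch

def flatten_keep_dim(input: torch.Tensor, dim: int) -> torch.Tensor:
    # Swap the selected dimension with the last one (a strided view, no copy).
    swapped = input.transpose(dim, -1)
    # The swapped view is generally non-contiguous, so copy it into
    # row-major order before view() can merge the leading dimensions.
    return swapped.contiguous().view(-1, input.shape[dim])

x = torch.arange(24).reshape(2, 3, 4)
out = flatten_keep_dim(x, 0)
print(out.shape)  # torch.Size([12, 2])
```

Unlike a full permutation, transpose() also swaps the old last dimension into the dim slot, so the rows come out in a different order than with movedim; each row is still a complete slice of the input along dim.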

Many thanks!
This is exactly what I need.
The ordering of the rows in the output is not important to me, and your code correctly selects the points from input by flattening all dimensions other than dim.
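Since only the set of rows matters here, that property can be checked directly. This sketch assumes the transpose-based flattening and compares the output rows, ignoring order, against every 1-D slice of the input taken along dim:

```python
import itertools
import torch

x = torch.arange(24).reshape(2, 3, 4)
dim = 1
out = x.transpose(dim, -1).contiguous().view(-1, x.shape[dim])

# Collect every 1-D slice ("fiber") of x along `dim`, with all
# other indices held fixed.
other = [range(n) for i, n in enumerate(x.shape) if i != dim]
fibers = []
for idx in itertools.product(*other):
    idx = list(idx)
    idx.insert(dim, slice(None))
    fibers.append(tuple(x[tuple(idx)].tolist()))

rows = {tuple(r.tolist()) for r in out}
print(rows == set(fibers))  # True
```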

As a side note, is the call to contiguous() expensive, or do I not need to be concerned about it?
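A small sketch of what contiguous() costs. If the tensor is already contiguous, it returns the tensor itself and nothing is copied. After transpose() the tensor is usually non-contiguous, so contiguous() allocates a new buffer and copies every element once, which is linear in the tensor size. That is typically minor next to the rest of a computation, but it can matter for very large tensors in hot loops. reshape() performs the same copy, and only when a plain view is impossible.

```python
import torch

x = torch.randn(2, 3, 4)
t = x.transpose(0, -1)
print(t.is_contiguous())  # False: transpose only changes the strides

c = t.contiguous()
# A new buffer was allocated and filled with the reordered elements.
print(c.data_ptr() == x.data_ptr())  # False

# Already contiguous: no copy, same underlying storage.
print(x.contiguous().data_ptr() == x.data_ptr())  # True

# reshape() copies here because t cannot be viewed as 2-D without one.
r = t.reshape(-1, x.shape[0])
```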