Reshape some axes but leave all remaining axes unchanged

Is there a method to reshape some axes of a tensor but leave all the remaining axes unchanged, without specifying the number of remaining axes (it should be inferred dynamically)?

Currently the user must spell out every remaining axis, once for each possible number of tensor dimensions, e.g. if(len(x.shape) == 4): pt.reshape(x, (x.shape[0]*x.shape[1], x.shape[2], x.shape[3])) elif(len...

For example, something like:

  • pt.reshape(x, (x.shape[0]*x.shape[1], -2)), where -2 would stand for all remaining dimensions, left unchanged, however many there are.
  • pt.mergeAxes(x, (0, 1)), a new function that merges axes 0 and 1 and leaves the rest untouched.
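One way to get the behaviour of the proposed -2 without any per-rank if/elif chain is to build the target shape from a slice of x.shape and unpack it. The following is a minimal sketch, assuming pt is torch; merge_leading_axes is a hypothetical helper name, not a PyTorch function.

```python
import torch

def merge_leading_axes(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: merge axes 0 and 1, pass all other axes through.

    x.shape[2:] expands to however many trailing sizes the tensor actually
    has, so one line covers every rank >= 2.
    """
    return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])

# The same function handles different ranks.
a = merge_leading_axes(torch.zeros(2, 3, 4, 5))  # 4-D -> 3-D, shape (6, 4, 5)
b = merge_leading_axes(torch.zeros(2, 3, 4))     # 3-D -> 2-D, shape (6, 4)
```

Only the merged axes are computed explicitly; the rest of the shape is copied from the input, which is exactly the "leave the remaining axes unchanged" part.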

Yes, it is odd that PyTorch doesn't natively provide better tensor shape manipulation functions. Maybe there is a good reason for this; I'm not sure.
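For the specific case of merging adjacent axes (the mergeAxes idea), PyTorch does ship two built-ins: torch.flatten(input, start_dim, end_dim) and its inverse Tensor.unflatten(dim, sizes). A short sketch, assuming a recent PyTorch release (integer-dim unflatten is not available in very old versions):

```python
import torch

x = torch.randn(2, 3, 4, 5)

# torch.flatten merges the contiguous dims start_dim..end_dim (inclusive)
# into one dim; all other dims are left as-is, whatever the rank.
merged = torch.flatten(x, start_dim=0, end_dim=1)  # shape (6, 4, 5)

# Tensor.unflatten splits one dim back into the given sizes.
restored = merged.unflatten(0, (2, 3))             # shape (2, 3, 4, 5)

# Flattening and unflattening preserve element order, so this round-trips.
assert torch.equal(restored, x)
```

The limitation is that flatten only merges a contiguous run of dims; merging non-adjacent axes still needs a permute first.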

There is einops, which sort of does what you want, though I don't think it does exactly this:

```python
import torch
from einops import rearrange, reduce, repeat

# rearrange elements according to the pattern: (t, b, c) -> (b, c, t)
seq = torch.randn(10, 4, 3)
output_tensor = rearrange(seq, 't b c -> b c t')
# combine rearrangement and reduction: 2x2 mean-pooling, channels moved last
images = torch.randn(4, 3, 32, 32)
output_tensor = reduce(images, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2)
# copy along a new axis: (h, w) -> (h, w, 3)
gray = torch.randn(32, 32)
output_tensor = repeat(gray, 'h w -> h w c', c=3)
```