I want to write a function that, given a tensor, returns an index tensor with the same shape. Basically, I want to replicate a 2-D index mask so that it matches a given tensor.shape.

tensor = torch.randn(1, 2, 3, 3)
index_mask = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
foo(index_mask, tensor) -> 4D_index_mask

Sorry, but I’m not sure I understood the expected output. Could you give an example?

I solved it using torch.arange and expand! Sorry for my poor explanation of the problem, and thanks for your interest.
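The actual solution code wasn't posted, so what follows is only a minimal sketch of how torch.arange and expand could be combined for this. It assumes the index mask runs over the last two dimensions of the tensor (3x3 in the example), and the function name expand_index_mask is hypothetical.

```python
import torch

def expand_index_mask(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: assumes the mask covers the last two dims (H, W).
    h, w = tensor.shape[-2:]
    # Build the 2-D index mask, e.g. [[0, 1, 2], [3, 4, 5], [6, 7, 8]] for 3x3.
    index_mask = torch.arange(h * w, device=tensor.device).view(h, w)
    # expand() broadcasts the 2-D mask over the leading dims without copying.
    return index_mask.expand(tensor.shape)

tensor = torch.randn(1, 2, 3, 3)
idx = expand_index_mask(tensor)
print(idx.shape)   # torch.Size([1, 2, 3, 3])
print(idx[0, 1])   # the 3x3 mask, repeated for every leading index
```

Note that expand returns a view that shares memory with the original 2-D mask, so if you need to modify the result in place, call .clone() (or .contiguous()) on it first.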