Expand each index of a tensor dimension a different number of times

Hey

Let’s say I have a tensor like:

a =  [[1, 1, 1],
      [2, 2, 2],
      [3, 3, 3]]

and that I want to repeat each row a certain number of times, without actually allocating memory. If I want each row to be repeated the same number of times, I could do something like:

a = a.unsqueeze(1).expand(-1, 2, -1).reshape(-1, 3)

Which gives:

a =  [[1, 1, 1],
      [1, 1, 1],
      [2, 2, 2],
      [2, 2, 2],
      [3, 3, 3],
      [3, 3, 3]]

However, how would I do it efficiently (mostly without allocating unnecessary memory) if I wanted to repeat each row a different number of times? Something like this (hypothetical, not a valid call):

a = a.unsqueeze(1).expand(-1, (2,3,4), -1).reshape(-1, 3)

Which would yield:

a =  [[1, 1, 1],
      [1, 1, 1],
      [2, 2, 2],
      [2, 2, 2],
      [2, 2, 2],
      [3, 3, 3],
      [3, 3, 3],
      [3, 3, 3],
      [3, 3, 3]]

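You can use torch.index_select along dim 0, passing an index tensor in which each row index appears as many times as that row should be repeated.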
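A minimal sketch of that suggestion, assuming the per-row counts are available as a 1-D integer tensor (the names `repeats`, `idx`, `out`, and `same` are made up for illustration). It builds the index with torch.repeat_interleave and then gathers the rows; calling torch.repeat_interleave on `a` directly should give the same result in one step.

```python
import torch

a = torch.tensor([[1, 1, 1],
                  [2, 2, 2],
                  [3, 3, 3]])

# Hypothetical per-row repeat counts: row 0 twice, row 1 three times, row 2 four times.
repeats = torch.tensor([2, 3, 4])

# Row indices with each index repeated by its count: [0, 0, 1, 1, 1, 2, 2, 2, 2]
idx = torch.repeat_interleave(torch.arange(a.size(0)), repeats)

# Gather the rows along dim 0 in that order.
out = torch.index_select(a, 0, idx)

# Alternative one-step form that skips the explicit index tensor.
same = torch.repeat_interleave(a, repeats, dim=0)

print(out)
```

Note that both forms allocate a new tensor for the result. A strided view cannot express a different repeat count per row, so some copy is unavoidable here. The same applies to the uniform case in the question: expand itself is a view, but the following reshape(-1, 3) has to copy because the expanded dimensions cannot be merged without materializing the data.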