Mimic numpy.repeat's advanced functionality in PyTorch

I am in the process of implementing a Message Passing Neural Network in PyTorch. Many implementations I have seen introduce zero-padding in order to fit the framework, but this means the implementations no longer match the theory described in the reference papers, especially once operations like batch norm are introduced.

A faithful implementation would be possible if we could make use of something like numpy.repeat, where each element can be repeated a given number of times, for example:

>>> import numpy as np
>>> x = [1, 2, 3]
>>> np.repeat(x, x)
array([1, 2, 2, 3, 3, 3])

We would then keep a separate array of counts so that we can sum over the repeated elements according to each node's number of neighbours.
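For concreteness, here is a minimal sketch of the kind of aggregation I mean, with made-up neighbour counts and a hand-written index tensor (building that index tensor automatically is exactly the missing piece), using index_add_ to sum the repeated rows back per node:

import torch

# made-up neighbour counts for three nodes (for illustration only)
counts = torch.tensor([1, 3, 2])

# one feature row per node
node_feats = torch.tensor([[1., 1.],
                           [2., 2.],
                           [3., 3.]])

# index tensor mapping each repeated row back to its source node,
# written by hand here; it corresponds to counts = [1, 3, 2]
idx = torch.tensor([0, 1, 1, 1, 2, 2])

messages = node_feats[idx]  # node i repeated counts[i] times, like np.repeat

# sum the repeated rows back per node (a segment sum over neighbours)
summed = torch.zeros_like(node_feats).index_add_(0, idx, messages)
# summed:
# tensor([[1., 1.],
#         [6., 6.],
#         [6., 6.]])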

I appreciate that such an implementation could be slower than zero-padding, but my main concern is staying true to the physical system and examining how a faithful implementation differs from those that are available.

This can be solved with advanced indexing, but it involves manually constructing an index tensor of type torch.long (a LongTensor).

A minimal example is:

import torch
a = torch.arange(35).view((5, 7))            # 5 rows of 7 consecutive integers
rep = torch.tensor([0, 1, 1, 1, 3, 3, 2])    # index of the source row for each output row
a_ = a[rep, :]                               # advanced indexing repeats the selected rows

This outputs

tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13],
        [ 7,  8,  9, 10, 11, 12, 13],
        [ 7,  8,  9, 10, 11, 12, 13],
        [21, 22, 23, 24, 25, 26, 27],
        [21, 22, 23, 24, 25, 26, 27],
        [14, 15, 16, 17, 18, 19, 20]])
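Ideally rep would not need to be written out by hand. A sketch of how it could be constructed from a per-row repeat count (the counts below are made up for illustration) is:

import torch

a = torch.arange(35).view((5, 7))

# made-up repeat counts, one per row of a
counts = torch.tensor([1, 3, 0, 2, 1])

# build the index tensor [0, 1, 1, 1, 3, 3, 4] from the counts
rep = torch.cat([torch.full((int(c),), i, dtype=torch.long)
                 for i, c in enumerate(counts)])

a_ = a[rep, :]  # row i of a repeated counts[i] times

This works, but it relies on a Python loop over the rows, which is the part I would like to replace with a vectorised numpy.repeat-style operation.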