Faster repeat_interleave?

I'm trying to produce [0, 0, 0, 1, 1, 1, 2, 2, 2, …], i.e. each element of an arange repeated a fixed number of times.
I tried the two methods below, but both are very slow. Is there a better approach?

from triton.testing import do_bench
import torch

size = 1024 * 1024

def _repeat_interleave(size, repeats):
    # Repeat each element of arange(size) `repeats` times: 0,0,...,1,1,...
    return torch.arange(size).repeat_interleave(repeats, output_size=size * repeats)

def _tile(size, repeats):
    # Same result via a column view tiled along dim 1, then flattened.
    return torch.arange(size).view(-1, 1).tile(1, repeats).flatten()

for d in [1, 2, 4, 8, 16, 32]:
    assert torch.allclose(_repeat_interleave(size, d), _tile(size, d))
    print(d, do_bench(lambda : _repeat_interleave(size, d)))
    print(d, do_bench(lambda : _tile(size, d)))
    print('-----')

Timings from do_bench (ms):

d     _repeat_interleave   _tile
1     0.378                0.380
2     3.736                3.933
4     30.093               19.520
8     52.662               53.494
16    71.413               70.171
32    143.931              138.314
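
For reference, one alternative I'm considering is to build the whole index vector with a single integer floor division instead of a gather or a tiled copy. This is just a sketch (the _div name is a placeholder of mine, not part of the benchmark above), and I haven't measured whether it's actually faster:

import torch

def _div(size, repeats):
    # Floor-dividing 0..size*repeats-1 by `repeats` gives each value of
    # arange(size) repeated `repeats` times, using one elementwise op
    # instead of a gather or a tiled copy.
    return torch.arange(size * repeats) // repeats

# e.g. _div(3, 3) -> tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

It should drop straight into the benchmark loop above in place of the other two helpers.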