Thanks to this post (Repeat examples along batch dimension), I solved it with torch.repeat_interleave, which repeats each slice of the tensor n times in a row along the chosen dimension:
import torch
# Seed the RNG so the sample values shown in the comments are reproducible.
torch.manual_seed(0)

# Shape of the input tensor. The same code applies to longer shapes, e.g.:
#   shape = x, y, z = _, _, _
#   shape = x, y, z, w, ... = _, _, _, _, ...
shape = (2, 3)
x, y = shape

# How many consecutive copies of each slice to produce along dim 0.
n = 3

t = torch.rand(*shape)
# tensor([[0.4963, 0.7682, 0.0885],
#         [0.1320, 0.3074, 0.6341]])

# Every row of `t` is emitted `n` times back-to-back before the next row.
t_interpolate = t.repeat_interleave(n, dim=0)
# tensor([[0.4963, 0.7682, 0.0885],
#         [0.4963, 0.7682, 0.0885],
#         [0.4963, 0.7682, 0.0885],
#         [0.1320, 0.3074, 0.6341],
#         [0.1320, 0.3074, 0.6341],
#         [0.1320, 0.3074, 0.6341]])

# Only dim 0 grows (by a factor of n); all trailing dims keep their size.
assert tuple(t_interpolate.shape) == (n * shape[0],) + tuple(shape[1:])