What is the best way to insert a single fixed value into a tensor's last dimension before every M elements?
This is what I’ve come up with so far:
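```python
import torch

T, N, C = (720, 32, 256)
x = torch.empty((T, N, C))  # uninitialized; only the shape matters here
M = 4                       # insert the value before every M elements of the last dim
value = torch.tensor(24, dtype=x.dtype, device=x.device)  # match x's dtype/device for torch.cat

x = torch.cat([
    value.expand(T, N, C // M, 1),   # one value slot per group of M
    x.reshape(T, N, C // M, M)       # split the last dim into C // M groups of M
], dim=-1).reshape(T, N, -1)

print(x.shape)  # torch.Size([720, 32, 320])
```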
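One possible alternative is sketched below. It is not a benchmarked answer. Instead of concatenating, it preallocates the output already filled with the fixed value and then copies each group of M elements into the remaining slots. `insert_every` is a hypothetical helper name, and the sketch assumes the last dimension is divisible by M.

```python
import torch

def insert_every(x: torch.Tensor, value: float, m: int) -> torch.Tensor:
    """Insert `value` before every `m` elements along the last dim of `x`.

    Hypothetical helper; assumes x.shape[-1] is divisible by m.
    """
    *lead, c = x.shape
    if c % m != 0:
        raise ValueError(f"last dim {c} is not divisible by {m}")
    # Output buffer pre-filled with the fixed value: one extra slot per group.
    # new_full keeps x's dtype and device.
    out = x.new_full((*lead, c // m, m + 1), value)
    # Copy the original groups into slots 1..m, leaving slot 0 of each group = value.
    out[..., 1:] = x.reshape(*lead, c // m, m)
    # Merge the group axis back into the last dimension.
    return out.reshape(*lead, -1)

T, N, C, M = 720, 32, 256, 4
x = torch.randn(T, N, C)
y = insert_every(x, 24.0, M)
print(y.shape)  # torch.Size([720, 32, 320])
```

Because `new_full` takes its dtype and device from `x`, this avoids mixing an integer `value` tensor with a float `x`. It also works for any number of leading dimensions. Whether it is faster than the `torch.cat` version is untested here.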