Concatenate a tensor to sub-dimensions

For a tensor a of shape 1 x n and a tensor b of shape batch x m x n, is there an elegant way to concatenate a along the second dimension of b, so that the result has shape batch x (m + 1) x n? The gradients from every batch element should also flow back to, and accumulate in, the same original a when calling loss.backward().

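You can decompose torch.cat into allocating an empty tensor and assigning to slices of it. The slice assignment broadcasts a over the batch dimension, and autograd records the copies, so a.grad accumulates the gradient from every batch element: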

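import torch
batch, m, n = b.shape  # a: 1 x n, b: batch x m x n
res = torch.empty((batch, m + 1, n), device=a.device, dtype=a.dtype)
res[:, :1] = a[None]  # a[None] is 1 x 1 x n and broadcasts to batch x 1 x n
res[:, 1:] = b        # fill the remaining m rows with b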
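Below is a minimal, self-contained sketch that wraps the slice-assignment approach in a function and checks the shape and the gradient that reaches a. The function names prepend_row and prepend_row_cat are hypothetical, and the second variant (torch.cat on an expanded view of a) is an alternative added for comparison, not part of the original answer.

```python
import torch

def prepend_row(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Prepend a (1 x n) to every batch of b (batch x m x n) -> batch x (m + 1) x n."""
    batch, m, n = b.shape
    res = torch.empty((batch, m + 1, n), device=b.device, dtype=b.dtype)
    res[:, :1] = a[None]  # 1 x 1 x n broadcasts over the batch dimension
    res[:, 1:] = b
    return res

def prepend_row_cat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same result via torch.cat; expand() makes a batch x 1 x n view without copying."""
    batch = b.shape[0]
    return torch.cat([a[None].expand(batch, -1, -1), b], dim=1)

a = torch.randn(1, 4, requires_grad=True)
b = torch.randn(3, 2, 4, requires_grad=True)

out = prepend_row(a, b)
print(out.shape)  # torch.Size([3, 3, 4])

out.sum().backward()
print(a.grad)  # every entry is 3.0: a receives gradient from all 3 batch elements
```

Both versions let autograd sum the per-batch gradients back into the single a; the cat variant avoids the explicit empty allocation, while the slice version makes the memory layout of the result explicit.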

Best regards

