For a tensor `a` of shape 1 x n and a tensor `b` of shape batch x m x n, is there an elegant way to concatenate `a` along the second dimension of `b`, so that the result has shape batch x (m + 1) x n? Meanwhile, the gradients from the different batches should all affect the same original `a` via `loss.backward()`.

You can decompose `cat` into allocating an empty tensor and assigning slices (slice assignment broadcasts):

```
batch, m, n = b.shape
res = torch.empty((batch, m + 1, n), device=a.device, dtype=a.dtype)
res[:, :1] = a[None]  # a[None] is 1 x 1 x n and broadcasts over the batch dimension
res[:, 1:] = b        # b is already batch x m x n
```
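To illustrate, here is a minimal, self-contained sketch (with assumed small shapes batch=4, m=3, n=2) checking that the gradients from every batch accumulate into the single original `a`:

```python
import torch

batch, m, n = 4, 3, 2
a = torch.randn(1, n, requires_grad=True)
b = torch.randn(batch, m, n)

# Allocate the result and fill it via broadcasting slice assignment.
res = torch.empty((batch, m + 1, n), dtype=a.dtype)
res[:, :1] = a[None]  # 1 x 1 x n, broadcast across all batches
res[:, 1:] = b

loss = res.sum()
loss.backward()

# Each element of `a` appears once per batch, and d(loss)/d(res) is all
# ones, so a.grad is `batch` (here 4.0) everywhere.
print(a.grad)
```

Equivalently, `torch.cat([a.expand(batch, 1, n), b], dim=1)` gives the same result, since `expand` also routes the gradient from every batch copy back into `a`.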

Best regards

Thomas
