For a tensor `a` of shape 1 x n and a tensor `b` of shape batch x m x n, is there an elegant way to concatenate `a` along the second dimension of `b`, so that the result has shape batch x (m + 1) x n? Meanwhile, the gradients from the different batches should all affect the same original `a` via `loss.backward()`.

You can decompose `cat` into allocating an empty tensor and assigning slices (slice assignment broadcasts):

```
batch, m, n = b.shape
res = torch.empty((batch, m + 1, n), device=a.device, dtype=a.dtype)
res[:, :1] = a[None]  # a[None] is 1 x 1 x n and broadcasts over the batch dimension
res[:, 1:] = b        # b is already batch x m x n
```
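To illustrate, here is a minimal, self-contained sketch (with assumed small shapes batch=4, m=3, n=2) checking that the gradients from every batch accumulate into the single original `a`:

```python
import torch

batch, m, n = 4, 3, 2
a = torch.randn(1, n, requires_grad=True)
b = torch.randn(batch, m, n)

# Allocate the result and fill it via broadcasting slice assignment.
res = torch.empty((batch, m + 1, n), dtype=a.dtype)
res[:, :1] = a[None]  # 1 x 1 x n, broadcast across all batches
res[:, 1:] = b

loss = res.sum()
loss.backward()

# Each element of `a` appears once per batch, and d(loss)/d(res) is all
# ones, so a.grad is `batch` (here 4.0) everywhere.
print(a.grad)
```

Equivalently, `torch.cat([a.expand(batch, 1, n), b], dim=1)` gives the same result, since `expand` also routes the gradient from every batch copy back into `a`.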

Best regards

Thomas
