Consider the tensors,
x = tensor([[1., 1.],
            [1., 1.],
            [1., 1.],
            [1., 1.],
            [1., 1.]])
y = tensor([0., 0.])
I want a batched torch.hstack that outputs something like,
torch.stack([torch.hstack((i, y)) for i in x])
which outputs,
tensor([[1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.]])
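(For reference, I know I can produce the same result without a loop by broadcasting y across the batch dimension and concatenating; a minimal sketch of that workaround, assuming the tensors above:)

import torch

x = torch.ones(5, 2)
y = torch.zeros(2)

# Broadcast y to one copy per row of x, then concatenate column-wise;
# gives the same (5, 4) result as the list comprehension above.
out = torch.cat([x, y.expand(x.shape[0], -1)], dim=1)

But I am specifically interested in whether torch.vmap can express this.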
I suspect Python list comprehensions are slow, and that the loop probably won't propagate the gradient through the computation graph (or will it?). So I wanted to convert it to torch.func.vmap
with something like,
torch.vmap(torch.hstack, in_dims=(0, None))((x, y))
but obviously that doesn't work, because torch.hstack takes its tensors as a single tuple argument, not as unpacked *args, so vmap cannot map over the inputs separately. What would be the correct way to do this?
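The closest I have gotten is wrapping torch.hstack in a lambda that takes the slice of x and y as two separate arguments; this seems to produce the right shape and to keep gradients flowing, though I am not sure it is the idiomatic pattern. A minimal sketch, assuming the tensors above (with requires_grad on x just for the gradient check):

import torch

x = torch.ones(5, 2, requires_grad=True)
y = torch.zeros(2)

# The lambda repacks its two positional arguments into the tuple that
# torch.hstack expects; in_dims=(0, None) maps over the rows of x while
# passing y through unchanged.
batched = torch.vmap(lambda a, b: torch.hstack((a, b)), in_dims=(0, None))
out = batched(x, y)  # shape (5, 4), same as the list-comprehension version

# Sanity check for the gradient question above: backprop does reach x.
out.sum().backward()
print(x.grad)  # ones, same shape as x

Is this lambda wrapper the intended way, or is there a more direct one?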