How to vmap with a tuple input?

Consider the tensors,

import torch

x = torch.tensor([[1., 1.],
                  [1., 1.],
                  [1., 1.],
                  [1., 1.],
                  [1., 1.]])
y = torch.tensor([0., 0.])

I want a batched torch.hstack that outputs something like,

torch.stack([torch.hstack((i, y)) for i in x])

which outputs,

tensor([[1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.]])

I guess Python list comprehensions are slow and probably will not propagate the gradient across the computation graph (or will they?). So I wanted to convert it to torch.func.vmap with something like,

torch.vmap(torch.hstack, in_dims=(0, None))((x, y))

but obviously that doesn’t work because the input to torch.hstack is a tuple, not an unpacked *tuple. What would be the correct way to do it?

Concatenating tensors will not break the computation graph as seen here:

import torch

x = torch.tensor([[1., 1.],
                  [1., 1.],
                  [1., 1.],
                  [1., 1.],
                  [1., 1.]], requires_grad=True)
y = torch.tensor([0., 0.], requires_grad=True)

z = torch.cat((x, y.unsqueeze(0).expand_as(x)), dim=1)
print(z)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.]], grad_fn=<CatBackward0>)

z.mean().backward()
print(x.grad)
# tensor([[0.0500, 0.0500],
#         [0.0500, 0.0500],
#         [0.0500, 0.0500],
#         [0.0500, 0.0500],
#         [0.0500, 0.0500]])
print(y.grad)
# tensor([0.2500, 0.2500])

so you might not need to use vmap for a simple concatenation.
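If you do still want the vmap form, here is a minimal sketch (my own illustration, not something tested in this thread): wrap torch.hstack so the two tensors arrive as separate positional arguments, then map over the first dimension of x while leaving y unbatched with in_dims=(0, None).

import torch

x = torch.ones(5, 2)
y = torch.zeros(2)

# Sketch: wrap hstack so each tensor is its own positional argument;
# in_dims=(0, None) vmaps over dim 0 of x and passes y through unbatched.
batched_hstack = torch.vmap(lambda a, b: torch.hstack((a, b)), in_dims=(0, None))

z = batched_hstack(x, y)
print(z)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.]])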

z = torch.cat((x, y.unsqueeze(0).expand_as(x)), dim=1)

Didn’t know about the expand_as() function. Thank you.