Vmap confusions: what is the output?

I’m using torch version 1.8.1. I imported vmap:
from torch._vmap_internals import vmap

I have read the documentation, but I’m still confused about the syntax. First, here is my objective.

I have a matrix, and I want to process its columns as a batch. Each column is multiplied by a random scalar, so each column produces a column of the same size, and the overall result should be a matrix with the same shape as the input. Is this possible with vmap?
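A minimal sketch of that objective with vmap, assuming the same `torch._vmap_internals` import as above (PyTorch 1.8.1). The random ±1 values are drawn outside vmap, one per column, and passed in as a second argument. `in_dims=(1, 0)` maps over the columns of the matrix and the entries of the scale vector together, and `out_dims=1` stacks the per-column results back as columns. The names `scale_column`, `m` and `signs` are illustrative only.

```python
import torch
from torch._vmap_internals import vmap  # same import as in the question

def scale_column(col, s):
    # col: one column of the matrix, shape (n_rows,)
    # s:   the scalar for that column, shape () inside vmap
    return col * s

torch.manual_seed(0)
m = torch.arange(12.0).reshape(3, 4)  # 3 rows, 4 columns

# One random +1/-1 per column, drawn outside vmap so each column gets its own value
signs = torch.randint(0, 2, (m.size(1),)).float() * 2 - 1

# in_dims=(1, 0): map over dim 1 (columns) of m and dim 0 of signs at the same time
# out_dims=1: put the mapped dimension back at position 1, i.e. results stay columns
result = vmap(scale_column, in_dims=(1, 0), out_dims=1)(m, signs)

print(result.shape)  # torch.Size([3, 4])
```

Column `j` of `result` should equal column `j` of `m` times `signs[j]`, the same as the broadcast `m * signs`.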

Please avoid suggesting alternative ways (without vmap) to do the above. This is just an example I chose to understand vmap.

What I assume:

  • The first argument to vmap is a function func. The result is a vectorized function vfunc.
  • The vectorized function vfunc applies func to each column of the given matrix.
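A small shape check of these assumptions, as a sketch assuming the same `torch._vmap_internals.vmap` import. With the default `in_dims=0`, vmap maps over dimension 0 (the rows), so mapping over columns needs `in_dims=1`. `out_dims` then controls where the mapped dimension ends up in the output. `double` is a hypothetical helper.

```python
import torch
from torch._vmap_internals import vmap

m = torch.arange(6.0).reshape(2, 3)  # 2 rows, 3 columns

def double(x):
    return x * 2

# Default in_dims=0: double sees one row at a time, shape (3,)
by_rows = vmap(double)(m)

# in_dims=1: double sees one column at a time, shape (2,);
# with the default out_dims=0 the per-column results are stacked as rows
by_cols = vmap(double, in_dims=1)(m)

# in_dims=1, out_dims=1: the per-column results are stacked back as columns
by_cols_kept = vmap(double, in_dims=1, out_dims=1)(m)

print(by_rows.shape)       # torch.Size([2, 3])
print(by_cols.shape)       # torch.Size([3, 2])
print(by_cols_kept.shape)  # torch.Size([2, 3])
```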

What I have tried:

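 import torch
 from random import random

 def generate_scalars(c):
     # random() is plain Python: it is evaluated once per call of this function
     scalar = 1 if random() > 0.5 else -1
     # torch.tensor(c.size()) builds a 1-D tensor whose values are c's sizes
     res = torch.tensor(c.size())
     res[:] = scalar
     return res.T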

The result of the above is a single column, not the matrix I expected. I’m sure I’m misunderstanding some part of this.
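A sketch that records what the function actually receives, assuming the same import and a 3×4 matrix (`inspect_column`, `seen_sizes` and `built` are hypothetical names). vmap calls the function once for the whole batch, and inside it `c` reports the size of a single column. So a plain-Python `random()` inside the function is drawn once and shared by all columns, and `torch.tensor(c.size())` builds a 1-element tensor holding the column length rather than a column-shaped tensor.

```python
import torch
from torch._vmap_internals import vmap

seen_sizes = []  # per-example sizes observed inside the function
built = []       # what torch.tensor(c.size()) produces inside the function

def inspect_column(c):
    seen_sizes.append(tuple(c.size()))
    built.append(torch.tensor(c.size()))
    return c * 2

m = torch.zeros(3, 4)  # 3 rows, 4 columns
out = vmap(inspect_column, in_dims=1, out_dims=1)(m)

print(len(seen_sizes))  # 1: the function ran once for all 4 columns
print(seen_sizes[0])    # (3,): the size of a single column
print(built[0])         # tensor([3]): the size stored as data, not a tensor of shape (3,)
print(out.shape)        # torch.Size([3, 4])
```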