I’m using torch version 1.8.1. I imported vmap:

`from torch._vmap_internals import vmap`

I have read the documentation but I’m still confused about the syntax. First, I’ll write down what my objective is.

I have a matrix. I want to process its columns as a batch: each column is multiplied by a random scalar, so each input column produces an output column of the same size. Is this possible with `vmap`?

Please avoid suggesting alternative ways (without `vmap`) to do the above. This is just an example I chose to understand `vmap`.
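For reference, a minimal sketch of the objective with `vmap` might look like the following. It assumes the random ±1 scalars are drawn once, outside the mapped function, and passed in as a second batched argument; the names `scale_column` and `signs` are hypothetical. `in_dims=(1, 0)` maps over the columns of `m` and the entries of `signs` together, and `out_dims=1` stacks the per-column results back as columns.

```python
import torch
from torch._vmap_internals import vmap

def scale_column(col, s):
    # col: one column of the matrix (1-D); s: the scalar for that column (0-D)
    return col * s

m = torch.arange(6.).reshape(2, 3)  # 2 rows, 3 columns

# Hypothetical choice: one random +1/-1 sign per column, drawn up front
signs = torch.randint(0, 2, (m.size(1),)).float() * 2 - 1

# in_dims=(1, 0): map over dim 1 (columns) of m and dim 0 of signs
# out_dims=1: place each result back along dim 1, so the output has m's shape
out = vmap(scale_column, in_dims=(1, 0), out_dims=1)(m, signs)
```

Drawing the scalars outside the function matters because the mapped function is called only once for the whole batch, so a Python-level `random()` call inside it would yield a single scalar shared by every column.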

What I assume:

- The first argument to `vmap` is a function `func`. The result is a vectorized function `vfunc`.
- The vectorized function `vfunc` applies `func` to each column of the given matrix.
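One way to check the second assumption is to map a reduction over the same matrix with different `in_dims`. This sketch assumes the `vmap` imported from `torch._vmap_internals` above; `total` is a hypothetical helper. With the default `in_dims=0`, `func` receives each row; with `in_dims=1`, it receives each column.

```python
import torch
from torch._vmap_internals import vmap

def total(v):
    # v is one 1-D slice of the matrix; sum its entries
    return v.sum(0)

m = torch.arange(6.).reshape(2, 3)    # [[0., 1., 2.], [3., 4., 5.]]
row_sums = vmap(total)(m)             # default in_dims=0, func sees rows: 3., 12.
col_sums = vmap(total, in_dims=1)(m)  # in_dims=1, func sees columns: 3., 5., 7.
```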

What I have tried:

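```
import torch
from random import random

def generate_scalars(c):
    # pick +1 or -1 at random
    scalar = 1 if random() > 0.5 else -1
    # note: builds a tensor from the values in c.size(), e.g. tensor([3])
    res = torch.tensor(c.size())
    res[:] = scalar
    return res.T
```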

The result of the above is a single column and not a matrix as I expected. I’m sure I’m misunderstanding some part of this.
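A possible source of the surprise, sketched outside `vmap` on a plain tensor `col` (a hypothetical stand-in for one column `c`; inside `vmap`, `c.size()` reports the size of a single column): `torch.tensor(c.size())` creates a tensor whose contents are the size values, not a tensor with that shape, whereas `torch.full_like` produces a tensor with the same shape as its input.

```python
import torch

col = torch.zeros(3)                     # stands in for one column c
from_size = torch.tensor(col.size())     # tensor([3]): built from the size values
same_shape = torch.full_like(col, -1.0)  # tensor([-1., -1., -1.]): same shape as col
```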