I’m using torch version 1.8.1. I imported vmap:
from torch._vmap_internals import vmap
I have read the documentation but I’m still confused about the syntax. First, I’ll write down what my objective is.
I have a matrix. I want to process its columns as a batch: each column is multiplied by a random scalar, so each column produces a column of the same size. Is this possible with `vmap`?
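A minimal sketch of that objective with `vmap`, assuming the random scalars are drawn *before* the `vmap` call and passed in as a second batched argument (the helper names `scale_column` and `scale_columns` are hypothetical). `in_dims` chooses which dimension of each input is mapped over, and `out_dims` chooses where the per-column results are stacked in the output. The import falls back to the private module used above if `torch.vmap` is not available.

```python
import torch

try:
    from torch import vmap
except ImportError:
    from torch._vmap_internals import vmap


def scale_column(col, s):
    # col: one column of the matrix (shape [n_rows]); s: its scalar (0-d tensor)
    return col * s


def scale_columns(matrix, scalars):
    # in_dims=(1, 0): map over dim 1 (columns) of `matrix` and dim 0 of `scalars`.
    # out_dims=1: stack the per-column results back along dim 1, so the output
    # has the same shape as `matrix`.
    return vmap(scale_column, in_dims=(1, 0), out_dims=1)(matrix, scalars)


torch.manual_seed(0)
matrix = torch.arange(12, dtype=torch.float32).reshape(3, 4)
# one random +1/-1 per column, drawn outside the vectorized function
scalars = torch.randint(0, 2, (matrix.size(1),)).float() * 2 - 1
result = scale_columns(matrix, scalars)
print(result.shape)  # torch.Size([3, 4])
```

Column `j` of `result` is column `j` of `matrix` times `scalars[j]`. Keeping the random draw outside the mapped function avoids relying on how `vmap` treats randomness inside `func`.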
Please avoid suggesting alternative ways (without `vmap`) to do the above. This is just an example I chose to understand `vmap`.
What I assume:
- The first argument to `vmap` is a function `func`. The result is a vectorized function `vfunc`.
- The vectorized function `vfunc` applies `func` to each column of the given matrix.
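For comparison, a small sketch of which dimension `vmap` maps over. By default `in_dims=0`, so `func` is applied to each *row*; mapping over columns needs `in_dims=1`. `torch.sum` is used here purely as an illustrative `func`, and the import fallback assumes either `torch.vmap` or the private module from the question is present.

```python
import torch

try:
    from torch import vmap
except ImportError:
    from torch._vmap_internals import vmap

matrix = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])

# Default in_dims=0: func is applied to each ROW, giving 1+2+3 and 4+5+6.
row_sums = vmap(torch.sum)(matrix)

# in_dims=1: func is applied to each COLUMN, giving 1+4, 2+5 and 3+6.
col_sums = vmap(torch.sum, in_dims=1)(matrix)

print(row_sums.shape, col_sums.shape)  # torch.Size([2]) torch.Size([3])
```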
What I have tried:
import torch
from random import random

def generate_scalars(c):
    # +1 or -1, each with probability 0.5
    scalar = 1 if random() > 0.5 else -1
    res = torch.tensor(c.size())
    res[:] = scalar
    return res.T
The result of the above is a single column and not a matrix as I expected. I’m sure I’m misunderstanding some part of this.
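A small sketch of two behaviours that may explain the single-column result, assuming a PyTorch build where `torch.vmap` or `torch._vmap_internals.vmap` is available. `sign_like` is a hypothetical helper, and `torch.full_like` is swapped in for the `torch.tensor(c.size())` line. Note that `vmap` calls the function *once* with batched inputs rather than once per column, so a plain Python `random()` call inside it is evaluated only once and every column gets the same sign.

```python
import random

import torch

try:
    from torch import vmap
except ImportError:
    from torch._vmap_internals import vmap

c = torch.ones(3)

# torch.tensor(c.size()) builds a tensor FROM the size tuple, i.e. a
# 1-element tensor holding the value 3, not a tensor OF shape [3].
print(torch.tensor(c.size()))  # tensor([3])

# torch.full_like(c, s) builds a tensor with c's shape and dtype, filled with s.
filled = torch.full_like(c, -1.0)


def sign_like(col):
    # random.random() is plain Python: vmap calls this function ONCE with a
    # batched tensor, so the same sign ends up in every column.
    scalar = 1.0 if random.random() > 0.5 else -1.0
    return torch.full_like(col, scalar)


matrix = torch.arange(6, dtype=torch.float32).reshape(2, 3)
# in_dims=1 maps over columns; out_dims=1 puts each result back as a column.
signs = vmap(sign_like, in_dims=1, out_dims=1)(matrix)
print(signs.shape)  # torch.Size([2, 3])
```

With `out_dims=1` the per-column outputs are reassembled into a matrix of the input's shape; with the default `out_dims=0` they would be stacked as rows instead.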