I’m using torch version 1.8.1. I imported vmap:
```python
from torch._vmap_internals import vmap
```
I have read the documentation, but I’m still confused about the syntax. First, here is my objective.
I have a matrix, and I want to process its columns as a batch: each column is multiplied by a random scalar, so each column produces a column of the same size. Is this possible with `vmap`?

Please avoid suggesting alternative ways (without `vmap`) to do the above. This is just an example I chose to understand `vmap`.
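A minimal sketch of the stated objective with `vmap`, assuming the `vmap(func, in_dims=0, out_dims=0)` signature of `torch._vmap_internals.vmap`. `in_dims=(1, 0)` maps over the columns of the matrix and over a vector of per-column scalars at the same time, and `out_dims=1` stacks the results back as columns. The random signs are drawn outside the mapped function so that each column gets its own draw. `scale_column` and `scale_columns` are hypothetical names.

```python
import random

import torch
from torch._vmap_internals import vmap  # import from the question (PyTorch 1.8.1)


def scale_column(col, s):
    # Inside vmap, `col` behaves like one column of shape (n,)
    # and `s` like a single 0-d scalar.
    return col * s


def scale_columns(matrix):
    # Draw one random sign (+1 or -1) per column outside vmap, because
    # Python-level randomness inside the mapped function runs only once.
    signs = torch.tensor(
        [1.0 if random.random() > 0.5 else -1.0 for _ in range(matrix.size(1))]
    )
    # in_dims=(1, 0): map over dim 1 (columns) of `matrix` and dim 0 of `signs`.
    # out_dims=1: stack the per-column results along dim 1, keeping shape (n, m).
    return vmap(scale_column, in_dims=(1, 0), out_dims=1)(matrix, signs)


m = torch.arange(6.0).reshape(2, 3)
out = scale_columns(m)
print(out.shape)  # torch.Size([2, 3])
```

Each column of `out` is either the matching column of `m` or its negation.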
What I assume:
- The first argument to `vmap` is a function `func`. The result is a vectorized function.
- The vectorized function applies `func` to each column of the given matrix.
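The assumptions can be checked with a small sketch using the same import. It suggests one refinement: `vmap` maps over dimension 0 by default, so handing `func` one column at a time needs `in_dims=1`, and the per-column outputs are stacked along dimension 0 unless `out_dims=1` is also given. The function `double` is a hypothetical stand-in for `func`.

```python
import torch
from torch._vmap_internals import vmap  # import from the question

matrix = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])  # shape (2, 3)


def double(v):
    # Written for a single 1-D slice of the matrix.
    return v * 2


# vmap(func) only builds a new function; the work happens when it is called.
over_rows = vmap(double)(matrix)                              # default in_dims=0: 2 rows
over_cols = vmap(double, in_dims=1)(matrix)                   # 3 columns, stacked along dim 0
over_cols_kept = vmap(double, in_dims=1, out_dims=1)(matrix)  # 3 columns, stacked along dim 1

print(over_rows.shape)       # torch.Size([2, 3])
print(over_cols.shape)       # torch.Size([3, 2])
print(over_cols_kept.shape)  # torch.Size([2, 3])
```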
What I have tried:
```python
import torch
from random import random


def generate_scalars(c):
    # pick +1 or -1 at random
    scalar = 1 if random() > 0.5 else -1
    res = torch.tensor(c.size())
    res[:] = scalar
    return res.T
```
The result of the above is a single column and not a matrix as I expected. I’m sure I’m misunderstanding some part of this.
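A small diagnostic sketch of the attempted code, assuming the same import. It points at two likely causes of the unexpected shape. First, `torch.tensor(c.size())` does not allocate a tensor of size `c.size()`; it wraps the size values themselves, so `res` has a single element and the column values are never used (and `.T` on a 1-D tensor returns it unchanged). Second, `vmap` invokes the Python function once for the whole batch, so the `random()` call also runs once and every column would receive the same scalar. `calls` and `counted` are hypothetical helpers used only to count invocations.

```python
import torch
from torch._vmap_internals import vmap  # import from the question

c = torch.tensor([10.0, 20.0, 30.0])

# torch.tensor(c.size()) builds a tensor *from the size values*:
# torch.Size([3]) becomes the one-element tensor [3], not a tensor of shape (3,).
print(torch.tensor(c.size()))    # tensor([3])

# A tensor with the same shape as c, filled with a scalar, is torch.full_like.
print(torch.full_like(c, -1.0))  # tensor([-1., -1., -1.])

# vmap calls the Python function once for the whole batch, so Python-level
# code such as random() inside it also runs once, not once per column.
calls = []


def counted(col):
    calls.append(1)
    return col * 2


vmap(counted, in_dims=1)(torch.ones(2, 3))
print(len(calls))  # 1
```

Passing per-column scalars in as a second mapped argument (for example with `in_dims=(1, 0)`) sidesteps the single-draw issue.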