Hi,
I want to implement a batched linear combination: a batch `bs` of tensors `b` of arbitrary shape, combined with a batch `as` of coefficient vectors `a`. That is, `bs` has size `[batch_size, n0, n1, ..., nk]` and `as` has size `[batch_size, n0]`; the output has size `[batch_size, n1, ..., nk]`.
The following does what I want:
```python
import torch

def cust_td(a, b):
    return torch.tensordot(a, b, dims=1)

batch_td = torch.vmap(cust_td)
# `as` is a reserved keyword in Python, so the coefficient batch
# needs a different name in actual code, e.g. a_s:
batch_td(a_s, b_s)
```
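For reference, the same contraction can also be written as a single `einsum`, which avoids the `vmap` wrapper entirely; a small sketch checking that both agree (the names `a_s`/`b_s` and the shapes `4, 3, 5, 6` are made-up examples):

```python
import torch

def cust_td(a, b):
    return torch.tensordot(a, b, dims=1)

batch_td = torch.vmap(cust_td)

# hypothetical example batches: batch_size=4, n0=3, trailing dims 5x6
a_s = torch.randn(4, 3)
b_s = torch.randn(4, 3, 5, 6)

out_vmap = batch_td(a_s, b_s)
# 'bn,bn...->b...' contracts the n0 axis per batch element
out_einsum = torch.einsum('bn,bn...->b...', a_s, b_s)

assert out_vmap.shape == (4, 5, 6)
assert torch.allclose(out_vmap, out_einsum, atol=1e-5)
```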
Is this fine from an efficiency perspective?
On another note, this question perhaps motivates allowing `torch.vmap` to accept keyword arguments (e.g. `dims=1` in my case) and pass them to the function `func` before vectorizing it.
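For what it's worth, the usual workaround today is to bind the keyword argument before handing the function to `vmap`, e.g. with `functools.partial` (a sketch; `a_s`/`b_s` and their shapes are made-up examples):

```python
import functools
import torch

# bind dims=1 to torch.tensordot first, then vectorize the resulting
# two-argument function over the batch dimension
batch_td = functools.partial(torch.tensordot, dims=1)
batched = torch.vmap(batch_td)

a_s = torch.randn(4, 3)        # hypothetical [batch_size, n0]
b_s = torch.randn(4, 3, 5, 6)  # hypothetical [batch_size, n0, n1, n2]
out = batched(a_s, b_s)

assert out.shape == (4, 5, 6)
```

A lambda works equally well: `torch.vmap(lambda a, b: torch.tensordot(a, b, dims=1))`.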