Batching `torch.tensordot` using `torch.vmap`

Hi,

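I want to implement a batched linear combination. I have a batch `bs` of tensors `b` of arbitrary shape and a batch `as` of coefficient vectors `a`, and for each batch element I want to combine the slices of `b` along its first dimension, weighted by `a`. This means `bs` has shape `[batch_size, n0, n1, ..., nk]`, `as` has shape `[batch_size, n0]`, and the output has shape `[batch_size, n1, ..., nk]`.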
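Written out elementwise, with $i$ indexing the batch and $m$ running over the $n_0$ dimension that gets contracted, the intended output is:

$$
\text{out}[i, j_1, \dots, j_k] = \sum_{m=0}^{n_0 - 1} \texttt{as}[i, m]\,\texttt{bs}[i, m, j_1, \dots, j_k]
$$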
The following does what I want:

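```python
import torch

def cust_td(a, b):
    # a: [n0], b: [n0, n1, ..., nk] -> [n1, ..., nk] (contracts over n0)
    return torch.tensordot(a, b, dims=1)

batch_td = torch.vmap(cust_td)
# `as` is a reserved keyword in Python, so the coefficient batch is named `as_` here
batch_td(as_, bs)
```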
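A self-contained way to check this on concrete shapes is sketched below. The sizes (`batch_size=4`, `n0=3`, trailing shape `(5, 6)`) are arbitrary illustrative choices, and the coefficient batch is called `as_` because `as` is a Python keyword. It compares the `vmap` result against a plain Python loop over the batch.

```python
import torch

def cust_td(a, b):
    # a: [n0], b: [n0, n1, ..., nk] -> [n1, ..., nk]
    return torch.tensordot(a, b, dims=1)

batch_td = torch.vmap(cust_td)

torch.manual_seed(0)
batch_size, n0 = 4, 3
as_ = torch.randn(batch_size, n0)        # coefficients, [batch_size, n0]
bs = torch.randn(batch_size, n0, 5, 6)   # tensors to combine, [batch_size, n0, 5, 6]

out = batch_td(as_, bs)

# Reference: for each batch element, sum as_[i, m] * bs[i, m] over m
ref = torch.stack([
    sum(as_[i, m] * bs[i, m] for m in range(n0))
    for i in range(batch_size)
])

print(out.shape)  # torch.Size([4, 5, 6])
print(torch.allclose(out, ref, atol=1e-5))
```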

Is this fine from an efficiency perspective?
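For comparison, the same contraction can be written without `vmap` as a single batched call. This is a sketch, not a benchmark: the `einsum` version uses `...` to cover the arbitrary trailing dimensions, and the `bmm` version flattens `n1, ..., nk` into one axis so the whole batch becomes one batched matrix product. The helper names `batch_td_einsum` and `batch_td_bmm` are hypothetical, made up for illustration. Which variant is fastest will depend on the shapes and device, so it is worth timing them (e.g. with `torch.utils.benchmark`) on the real workload.

```python
import torch

def batch_td_einsum(as_, bs):
    # as_: [B, n0], bs: [B, n0, n1, ..., nk] -> [B, n1, ..., nk]
    return torch.einsum("bn,bn...->b...", as_, bs)

def batch_td_bmm(as_, bs):
    B, n0 = as_.shape
    trailing = bs.shape[2:]
    # [B, 1, n0] @ [B, n0, n1*...*nk] -> [B, 1, n1*...*nk]
    out = torch.bmm(as_.unsqueeze(1), bs.reshape(B, n0, -1))
    # restore the original trailing dimensions
    return out.reshape(B, *trailing)

torch.manual_seed(0)
as_ = torch.randn(4, 3)
bs = torch.randn(4, 3, 5, 6)

vmapped = torch.vmap(lambda a, b: torch.tensordot(a, b, dims=1))(as_, bs)
print(torch.allclose(batch_td_einsum(as_, bs), vmapped, atol=1e-5))
print(torch.allclose(batch_td_bmm(as_, bs), vmapped, atol=1e-5))
```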

On another note, this question may motivate allowing `torch.vmap` to accept keyword arguments (e.g. `dims=1` in my case) and bind them to the function `func` before vectorizing it.
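As a workaround in the meantime, one option is to bind the keyword argument before vectorizing, e.g. with `functools.partial` (a lambda works just as well). A minimal sketch, with arbitrary illustrative shapes:

```python
import functools
import torch

# Bind dims=1 up front; vmap then only maps over the two positional tensors.
batch_td = torch.vmap(functools.partial(torch.tensordot, dims=1))

as_ = torch.randn(4, 3)        # [batch_size, n0]
bs = torch.randn(4, 3, 5, 6)   # [batch_size, n0, n1, n2]
out = batch_td(as_, bs)

# Same result as calling tensordot on each batch element separately
ref = torch.stack([torch.tensordot(a, b, dims=1) for a, b in zip(as_, bs)])
print(out.shape)  # torch.Size([4, 5, 6])
```

This avoids writing a separate wrapper like `cust_td` for each choice of `dims`.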