Batched meshgrid?

Normally, torch.meshgrid takes a list of 1D tensors. I'd like to do the same for a list of batched 1D tensors, i.e. 2D tensors of shape BxD, while keeping the batch entries independent of each other. Is this possible?

Normal:

torch.meshgrid(1d Tensor, 1d Tensor, ...)

Want:

torch.meshgrid(BxD Tensor, BxD Tensor, ...)
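One possible way to get this behaviour without a Python loop over the batch is broadcasting: reshape each (B, D_k) input so that only its own grid axis varies, then expand it to the full grid shape. This is a minimal sketch assuming "ij" indexing and that all inputs share the same batch size B; batched_meshgrid is a hypothetical name, not a PyTorch function.

```python
import torch

def batched_meshgrid(*tensors):
    """Hypothetical helper: batched analogue of torch.meshgrid(..., indexing="ij").

    Each input has shape (B, D_k); each output has shape (B, D_1, ..., D_n).
    Assumes all inputs share the same batch size B.
    """
    batch = tensors[0].shape[0]
    sizes = [t.shape[1] for t in tensors]
    n = len(tensors)
    grids = []
    for k, t in enumerate(tensors):  # loops over inputs, not over the batch
        # Reshape to (B, 1, ..., D_k, ..., 1) so that only grid axis k varies.
        shape = [batch] + [1] * n
        shape[k + 1] = sizes[k]
        # Broadcast the singleton axes up to the full grid shape (a view, no copy).
        grids.append(t.reshape(shape).expand(batch, *sizes))
    return tuple(grids)


a = torch.tensor([[0, 1], [10, 11]])       # B=2, D1=2
b = torch.tensor([[5, 6, 7], [8, 9, 10]])  # B=2, D2=3
xs, ys = batched_meshgrid(a, b)
print(xs.shape)  # torch.Size([2, 2, 3])
```

Because expand returns views with zero strides, no grid-sized memory is allocated until the outputs are cloned or written to.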

Related: Compute cartesian product for batched tensor
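For the related cartesian-product case, the two-tensor version can be vectorized with repeat_interleave and repeat. This is a sketch assuming inputs a of shape (B, D1) and b of shape (B, D2) with the same dtype; batched_cartesian_prod is a hypothetical name. Each out[i] is intended to match torch.cartesian_prod(a[i], b[i]).

```python
import torch

def batched_cartesian_prod(a, b):
    """Hypothetical helper: per-sample cartesian product of two batched tensors.

    a: (B, D1), b: (B, D2) -> (B, D1 * D2, 2)
    """
    d1, d2 = a.shape[1], b.shape[1]
    # Each element of a is repeated D2 times in a row: a0, a0, a0, a1, a1, a1, ...
    left = a.repeat_interleave(d2, dim=1)
    # The whole row of b is tiled D1 times: b0, b1, b2, b0, b1, b2, ...
    right = b.repeat(1, d1)
    # Pair the two sequences up along a new last axis.
    return torch.stack([left, right], dim=-1)


a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6, 7], [8, 9, 10]])
print(batched_cartesian_prod(a, b)[0])
# tensor([[1, 5], [1, 6], [1, 7], [2, 5], [2, 6], [2, 7]])
```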

Equivalent to:

meshes = []
for i in range(batch_size):
    meshes.append(torch.meshgrid(tensor1[i], tensor2[i], ...))

and then stacking the zipped meshes along the batch dimension. This works, but the Python loop over the batch is obviously inefficient.
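For reference, the loop-then-stack approach can be made runnable for any number of inputs roughly as follows. This sketch assumes "ij" indexing (passed explicitly, since recent PyTorch versions warn when indexing is omitted) and that all inputs share the same batch size; meshgrid_loop is a hypothetical name. It is mainly useful as a correctness baseline when testing a vectorized version.

```python
import torch

def meshgrid_loop(*tensors):
    """Hypothetical per-sample reference: run torch.meshgrid on each batch row,
    then zip the per-sample outputs and stack them along a new batch axis."""
    batch_size = tensors[0].shape[0]
    meshes = []
    for i in range(batch_size):
        meshes.append(torch.meshgrid(*(t[i] for t in tensors), indexing="ij"))
    # meshes[i] is a tuple (grid_1, ..., grid_n) for sample i; zip regroups them
    # by grid index so that each group can be stacked over the batch.
    return tuple(torch.stack(group) for group in zip(*meshes))


a = torch.tensor([[0, 1], [10, 11]])
b = torch.tensor([[5, 6, 7], [8, 9, 10]])
xs, ys = meshgrid_loop(a, b)
print(xs.shape)  # torch.Size([2, 2, 3])
```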
