How to compute large-scale outer products efficiently

If you have a tensor A of shape [m, M] and another tensor B of shape [m, N], you can compute a batched outer product (one outer product per row, giving a result of shape [m, M, N]) with torch.einsum:

outer_product = torch.einsum("bi,bj->bij", A, B)
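
As a quick sanity check, the sketch below wraps the einsum call in a function and compares it with an equivalent broadcasting form. The function names `batched_outer` and `batched_outer_broadcast` and the small example sizes are assumptions made for illustration, not part of the original snippet. Either way the result holds m * M * N elements, so memory use grows quickly with large inputs.

```python
import torch


def batched_outer(A, B):
    # A: [m, M], B: [m, N] -> result: [m, M, N]
    # For each row b, result[b] is the outer product of A[b] and B[b].
    return torch.einsum("bi,bj->bij", A, B)


def batched_outer_broadcast(A, B):
    # Same computation via broadcasting: [m, M, 1] * [m, 1, N] -> [m, M, N]
    return A.unsqueeze(2) * B.unsqueeze(1)


# Small illustrative sizes (hypothetical): m=4, M=3, N=5
torch.manual_seed(0)
A = torch.randn(4, 3)
B = torch.randn(4, 5)

out = batched_outer(A, B)
print(out.shape)
print(torch.allclose(out, batched_outer_broadcast(A, B)))
```

Both forms produce the same tensor; which one runs faster may depend on the device and tensor sizes, so it is worth timing them on your own workload.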