Parameterized cross product of matrices?

Hi folks, I wish to take a cross product of 2 stacks of vectors (i.e. compute something for every pair), but I'd like this operation to be parameterized. For example, I can get an n x n matrix of pairwise dot products with torch.matmul(h, k.permute(0, 2, 1)), where h is of size (batch, n, emb) and k is of size (batch, n, emb).
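For concreteness, this is the unparameterized baseline I mean (sizes here are just placeholders for illustration):

import torch

batch, n, emb = 4, 8, 16  # placeholder sizes
h = torch.randn(batch, n, emb)
k = torch.randn(batch, n, emb)

# (batch, n, emb) @ (batch, emb, n) -> (batch, n, n) of pairwise dot products
scores = torch.matmul(h, k.permute(0, 2, 1))
print(scores.shape)  # torch.Size([4, 8, 8])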

I’m wondering if there’s a way to do this such that I have an MLP module that takes in each pair from h and k and projects it to a score? For now, what I’ve done is this:

h = h.repeat_interleave(n, dim=1)  # (batch, n*n, emb): each row of h repeated n times
k = k.repeat(1, n, 1)              # (batch, n*n, emb): k tiled n times

hk = torch.cat((h, k), dim=-1)     # (batch, n*n, 2*emb): one row per (h_i, k_j) pair

j = MLP(hk)                        # (batch, n*n, 1), reshaped to (batch, n, n) afterwards

But this blows up the GPU memory, since it materializes all n * n concatenated pairs per batch element at once. Is there a more efficient way of doing this?
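The only workaround I can think of is chunking over rows of h so the full n x n grid of pairs is never materialized at once. A rough sketch of what I mean (the MLP here is just a placeholder mapping 2*emb features to one score, and the chunk size is arbitrary):

import torch
import torch.nn as nn

emb = 64
mlp = nn.Sequential(nn.Linear(2 * emb, emb), nn.ReLU(), nn.Linear(emb, 1))

def pairwise_scores_chunked(h, k, chunk=16):
    # h, k: (batch, n, emb) -> scores: (batch, n, n)
    batch, n, _ = h.shape
    out = []
    for start in range(0, n, chunk):
        h_chunk = h[:, start:start + chunk]                      # (batch, c, emb)
        # broadcast to all pairs for this chunk only
        hh = h_chunk.unsqueeze(2).expand(-1, -1, n, -1)          # (batch, c, n, emb)
        kk = k.unsqueeze(1).expand(-1, h_chunk.size(1), -1, -1)  # (batch, c, n, emb)
        pairs = torch.cat((hh, kk), dim=-1)                      # (batch, c, n, 2*emb)
        out.append(mlp(pairs).squeeze(-1))                       # (batch, c, n)
    return torch.cat(out, dim=1)                                 # (batch, n, n)

This keeps peak memory proportional to chunk * n instead of n * n, but it feels clunky, so I'd love to know if there is a cleaner or more idiomatic way.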