Hi folks, I want to compute all pairwise combinations of two stacks of vectors, but I'd like the pairwise operation itself to be parameterized. For example, I can get a batch of n x n score matrices with torch.matmul(h, k.permute(0, 2, 1)), where h and k each have size (batch, n, emb).
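For concreteness, here is a minimal sketch of that matmul baseline (shapes chosen arbitrarily for illustration):

```python
import torch

batch, n, emb = 2, 4, 8
h = torch.randn(batch, n, emb)
k = torch.randn(batch, n, emb)

# Pairwise dot products: scores[b, i, j] = <h[b, i], k[b, j]>
scores = torch.matmul(h, k.permute(0, 2, 1))
print(scores.shape)  # torch.Size([2, 4, 4])
```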
I’m wondering if there’s a way to do this with an MLP module that takes each pair (h_i, k_j) and projects it down to a score. For now, what I’ve done is this:
```python
h = h.repeat_interleave(n, dim=1)  # (batch, n*n, emb)
k = k.repeat(1, n, 1)              # (batch, n*n, emb)
hk = torch.cat((h, k), dim=-1)     # (batch, n*n, 2*emb): cat along the feature dim, not dim=1
j = MLP(hk)
```
But this blows up the GPU memory. Is there a more efficient way of doing this?
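One memory-friendlier idea I've been considering (a sketch, untested at scale; PairScore is a hypothetical module name): since the first linear layer of an MLP applied to the concatenation [h_i ; k_j] decomposes as W_h h_i + W_k k_j, you can project h and k separately and broadcast-add the results instead of materializing the repeated (batch, n*n, 2*emb) input:

```python
import torch
import torch.nn as nn

class PairScore(nn.Module):
    # Hypothetical module: equivalent to a one-hidden-layer MLP applied to
    # cat([h_i, k_j], dim=-1) for every pair (i, j), but it never builds the
    # (batch, n*n, 2*emb) concatenated input tensor.
    def __init__(self, emb, hidden, out):
        super().__init__()
        self.proj_h = nn.Linear(emb, hidden)               # W_h half of the first layer (carries the bias)
        self.proj_k = nn.Linear(emb, hidden, bias=False)   # W_k half of the first layer
        self.out = nn.Linear(hidden, out)

    def forward(self, h, k):
        # (batch, n, 1, hidden) + (batch, 1, n, hidden) -> (batch, n, n, hidden)
        z = self.proj_h(h).unsqueeze(2) + self.proj_k(k).unsqueeze(1)
        return self.out(torch.relu(z))  # (batch, n, n, out)

batch, n, emb = 2, 4, 8
h = torch.randn(batch, n, emb)
k = torch.randn(batch, n, emb)
j = PairScore(emb, hidden=16, out=1)(h, k)
print(j.shape)  # torch.Size([2, 4, 4, 1])
```

Note this still creates an n x n x hidden activation after the first layer, so it helps most when hidden is smaller than 2*emb or when you additionally chunk over rows of h.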