Hi folks, I want to take a cross product of two stacks of vectors, i.e. combine every vector in one stack with every vector in the other, but I'd like this pairwise operation to be parameterized. For example, I can get an n x n score matrix with torch.matmul(h, k.permute(0, 2, 1)), where h is of size (batch, n, emb) and k is of size (batch, n, emb).
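For concreteness, here is a minimal runnable version of that baseline (the sizes are just dummies for illustration):

import torch

batch, n, emb = 32, 16, 64
h = torch.randn(batch, n, emb)
k = torch.randn(batch, n, emb)

# scores[b, i, j] = dot product of h[b, i] and k[b, j]
scores = torch.matmul(h, k.permute(0, 2, 1))   # (batch, n, n)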
I'm wondering if there's a way to do this with an MLP module that takes each (h, k) pair and projects it to a score. For now, what I've done is this:
n = h.size(1)                          # h, k: (batch, n, emb)
h = h.repeat_interleave(n, dim=1)      # (batch, n*n, emb): h_0 ... h_0, h_1 ... h_1, ...
k = k.repeat(1, n, 1)                  # (batch, n*n, emb): k_0, k_1, ..., k_0, k_1, ...
hk = torch.cat((h, k), dim=-1)         # (batch, n*n, 2*emb): every (h_i, k_j) pair
j = MLP(hk)                            # e.g. layers mapping 2*emb -> 1 per pair
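(To put numbers on it: hk materializes all n² pairs at once, so for, say, n = 512 and emb = 256 it alone holds 512² × 512 ≈ 1.3 × 10⁸ floats per batch element, roughly 0.5 GB in fp32, before even counting the MLP's activations.)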
But this blows up the GPU memory. Is there a more efficient way of doing this?