How can I do a mutual concatenation?

Hi guys,

I've run into a problem that I cannot solve. Say I have a tensor A with shape (batch_A, hidden_dim) and a tensor B with shape (batch_B, hidden_dim). I would like to reshape them into (batch_A, h, hidden_dim/h) and (batch_B, h, hidden_dim/h), concatenate every row of A with every row of B (a mutual concatenation) to get a tensor of shape (batch_A x batch_B, 2h, hidden_dim/h), and then apply a convolution to it. What would be the best way to do such an operation? As far as I know, I could use this:

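```python
import torch

d = hidden_dim // h  # integer division; hidden_dim / h gives a float
A = A.reshape(-1, h, d)                              # (batch_A, h, d)
B = B.reshape(-1, h, d)                              # (batch_B, h, d)
A = A.unsqueeze(0).repeat(batch_B, 1, 1, 1)          # (batch_B, batch_A, h, d)
B = B.unsqueeze(1).repeat(1, batch_A, 1, 1)          # (batch_B, batch_A, h, d)
concat_tensor = torch.cat([A, B], dim=-2)            # (batch_B, batch_A, 2h, d)
concat_tensor = concat_tensor.reshape(-1, 2 * h, d)  # (batch_A * batch_B, 2h, d)
res = conv(concat_tensor)
```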
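A variant of this snippet, sketched under the assumption that `conv` takes a `(N, 2h, hidden_dim // h)` input: `expand()` returns broadcast views instead of copies, so the two enlarged copies of A and B never exist in memory. Note that `torch.cat` still has to allocate the full concatenated result, so this only removes the intermediate copies. The helper name `mutual_concat` is hypothetical.

```python
import torch

def mutual_concat(A, B, h):
    """Pair every row of A with every row of B along the channel axis.

    A: (batch_A, hidden_dim), B: (batch_B, hidden_dim)
    Returns (batch_B * batch_A, 2h, hidden_dim // h); row j * batch_A + i
    holds cat(A[i], B[j]), i.e. B is the outer index.
    """
    batch_A, hidden_dim = A.shape
    batch_B = B.shape[0]
    d = hidden_dim // h
    # expand() creates views with stride 0 on the new axis: no data is copied
    A = A.reshape(batch_A, h, d).unsqueeze(0).expand(batch_B, -1, -1, -1)
    B = B.reshape(batch_B, h, d).unsqueeze(1).expand(-1, batch_A, -1, -1)
    # torch.cat materializes the (batch_B, batch_A, 2h, d) result
    out = torch.cat([A, B], dim=-2)
    return out.reshape(batch_B * batch_A, 2 * h, d)
```

Because the output of `torch.cat` is contiguous, the final `reshape` is a free view.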

Is there a better way to do this? repeat() copies the data, so it consumes a lot of memory, which is not ideal for me because one of my batch sizes can be as large as 15,000.
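If `conv` is a single `nn.Conv1d` applied directly to the concatenated tensor, with the 2h rows as input channels (an assumption; the post does not say which layer is used), the concatenation can be skipped entirely. Convolution is linear in its input channels, so the kernel W splits into the half that sees A's h channels and the half that sees B's: conv(cat(a, b)) = conv1d(a, W[:, :h]) + conv1d(b, W[:, h:]) + bias. Each half is then computed once per row (batch_A + batch_B convolutions instead of batch_A x batch_B), and the pairwise result is a broadcast add. The helper `pairwise_conv` is hypothetical and assumes `groups == 1` and zero padding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pairwise_conv(conv, A, B, h):
    """Compute conv on every (A[i], B[j]) channel concatenation without building it.

    Assumes conv is an nn.Conv1d with in_channels == 2h, groups == 1 and
    padding_mode == "zeros". Row j * batch_A + i of the result corresponds
    to cat(A[i], B[j]).
    """
    assert conv.in_channels == 2 * h and conv.groups == 1
    assert conv.padding_mode == "zeros"
    batch_A, hidden_dim = A.shape
    batch_B = B.shape[0]
    d = hidden_dim // h
    # Split the kernel by input channel: first h channels see A, last h see B
    W_A, W_B = conv.weight[:, :h], conv.weight[:, h:]
    kwargs = dict(stride=conv.stride, padding=conv.padding, dilation=conv.dilation)
    yA = F.conv1d(A.reshape(batch_A, h, d), W_A, **kwargs)  # (batch_A, out, L)
    yB = F.conv1d(B.reshape(batch_B, h, d), W_B, **kwargs)  # (batch_B, out, L)
    if conv.bias is not None:
        yB = yB + conv.bias[:, None]  # add the bias once per output channel
    # Broadcast add over all pairs: (batch_B, 1, out, L) + (1, batch_A, out, L)
    y = yB.unsqueeze(1) + yA.unsqueeze(0)
    return y.reshape(batch_B * batch_A, *y.shape[2:])
```

The output still has batch_A x batch_B rows, so with a batch size around 15,000 it may be necessary to process B in slices (for example `B[start:start + chunk]`) and consume each slice's result before moving on. The trick does not apply if a nonlinearity or normalization sits between the concatenation and the convolution, or if the layer is an `nn.Conv2d` whose kernel spans the boundary between the A rows and the B rows.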

Regards,