I’ve run into a problem that I cannot solve. Say I have a tensor A with shape (batch_A, hidden_dim) and a tensor B with shape (batch_B, hidden_dim). I would like to reshape both tensors into shapes (batch_A, h, hidden_dim/h) and (batch_B, h, hidden_dim/h), concatenate every pair of rows into shape (batch_A x batch_B, 2h, hidden_dim/h), and then run a Conv over the result. What would be the best way to do such an operation? As far as I know, I could use this:
```python
A = A.reshape(-1, h, hidden_dim // h)
B = B.reshape(-1, h, hidden_dim // h)
A = A.unsqueeze(0).repeat(batch_B, 1, 1, 1)  # (batch_B, batch_A, h, hidden_dim // h)
B = B.unsqueeze(1).repeat(1, batch_A, 1, 1)  # (batch_B, batch_A, h, hidden_dim // h)
concat_tensor = torch.cat([A, B], dim=-2)    # (batch_B, batch_A, 2h, hidden_dim // h)
concat_tensor = concat_tensor.reshape(-1, 2 * h, hidden_dim // h)
res = conv(concat_tensor)
```
Is there a better way to do this? repeat() consumes a lot of memory, which is not ideal for me because one of my batch sizes can be as large as 15,000.
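For reference, here is a small runnable sketch of the same pipeline (the toy sizes and the Conv1d parameters are assumptions, not from my real setup) that uses expand() in place of repeat(). expand() returns a broadcasted view without allocating new memory, so the copy only happens once, at torch.cat, although the full pairwise tensor is still materialized there:

```python
import torch

# Toy sizes for illustration only
batch_A, batch_B, hidden_dim, h = 4, 3, 8, 2

A = torch.randn(batch_A, hidden_dim).reshape(batch_A, h, hidden_dim // h)
B = torch.randn(batch_B, hidden_dim).reshape(batch_B, h, hidden_dim // h)

# expand() creates broadcasted views; no extra memory is allocated here
A_exp = A.unsqueeze(0).expand(batch_B, -1, -1, -1)  # (batch_B, batch_A, h, hidden_dim // h)
B_exp = B.unsqueeze(1).expand(-1, batch_A, -1, -1)  # (batch_B, batch_A, h, hidden_dim // h)

# The copy is materialized only at the concatenation
pairs = torch.cat([A_exp, B_exp], dim=-2)           # (batch_B, batch_A, 2h, hidden_dim // h)
pairs = pairs.reshape(-1, 2 * h, hidden_dim // h)   # (batch_A * batch_B, 2h, hidden_dim // h)

# Hypothetical conv: 2h input channels over the hidden_dim // h axis
conv = torch.nn.Conv1d(2 * h, 16, kernel_size=3, padding=1)
res = conv(pairs)
print(res.shape)  # torch.Size([12, 16, 4])
```

With very large batches it may also help to chunk one of the batch dimensions and run the conv per chunk, so the full (batch_A x batch_B) tensor never exists at once.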