I have a matrix A with dimension [batch_size,N,M,D] and another matrix B with dimension [batch_size,P,D]. I want to get a tensor C as output with dimension [batch_size,N,M,P] in the following way:
C[i,j] = A[i,j] · B[i]^T (a matrix product of shape [M,D] x [D,P] = [M,P]), for 0 <= i < batch_size and 0 <= j < N.
What is the most memory-efficient way to do this?
Thanks!
import torch

# Example sizes for the shapes described in the question.
batch_size, N, M, D, P = 4, 5, 6, 7, 8
A = torch.rand(batch_size, N, M, D)
B = torch.rand(batch_size, P, D)

# Contract the shared d axis within each batch element:
# C[b, n] = A[b, n] @ B[b].T, computed for all (b, n) in a single call.
C = torch.einsum('bnmd,bpd -> bnmp', A, B)
print(C.size())  # batch_size, N, M, P
This does not seem to give the correct result when I compare it against an explicit loop:
> C = torch.zeros(batch_size, N,M,P)
> for i in range(batch_size):
>     for j in range(N):
>         C[i,j] = torch.mm(A[i,j],B[i].transpose(0,1))
> is_same = C==torch.einsum('bnmd,bpd -> bnmp', [A, B])
import torch

# Small integer-valued tensors so the loop result and the einsum result
# can be compared exactly rather than with a floating-point tolerance.
batch_size, N, M, D, P = 2, 2, 2, 2, 3
A = torch.arange(16).view(batch_size, N, M, D)
B = torch.arange(16, 28).view(batch_size, P, D)

# Reference result: one (M, D) x (D, P) matrix product per (i, j) pair.
C = torch.zeros(batch_size, N, M, P)
for i in range(batch_size):
    for j in range(N):
        C[i, j] = torch.mm(A[i, j], B[i].transpose(0, 1))

# The same contraction over d expressed as a single einsum call.
F = torch.einsum('bnmd,bpd -> bnmp', A, B)
print(F.equal(C.long()))  # Prints True