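The code below cross-checks a PyTorch `bmm`-based computation against an equivalent NumPy `einsum` implementation and prints both results for comparison. I hope it is correct and useful to you.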
def SelfTest(N=10, D=64, P=5, Q=5, K=3):
    """Cross-check a torch.bmm implementation against a NumPy einsum reference.

    Two random batches of shape (N, D, P*K, Q*K) are drawn. The spatial grid is
    split as row = p*K + a, col = q*K + b. For every sample n and every in-cell
    offset (a, b), the D x D channel Gram matrix over the P*Q cells is computed:

        out[n, i*D + j, a, b] = sum_{p, q} x1[n, i, p*K+a, q*K+b] * x2[n, j, p*K+a, q*K+b]

    Both implementations are evaluated and compared element-wise. Comparing only
    global sums is not a valid check: the total is invariant under any
    regrouping of spatial positions, so it cannot detect layout mistakes.

    Args:
        N, D, P, Q: batch size, channel count, and number of cells per axis.
        K: cell (patch) size along each spatial axis; 3 in the original test.

    Returns:
        True if both implementations agree within float32 tolerance.
    """
    import numpy as np
    import torch

    batch1 = torch.randn(N, D, P * K, Q * K)
    batch2 = torch.randn(N, D, P * K, Q * K)

    def Transform(x):
        # (N, D, P*K, Q*K) -> (N, D, P, K, Q, K): row = p*K + a, col = q*K + b
        # -> (N, K, K, D, P, Q) -> (N*K*K, D, P*Q)
        # reshape copies the non-contiguous permuted tensor as needed.
        return (
            x.view(N, D, P, K, Q, K)
            .permute(0, 3, 5, 1, 2, 4)
            .reshape(N * K * K, D, P * Q)
        )

    A = Transform(batch1)
    B = Transform(batch2)
    # bmm output is laid out as [n, a, b, i, j]. Move (a, b) behind (i, j)
    # before flattening; a direct view to (N, D*D, K, K) would mix the axes.
    res2 = (
        torch.bmm(A, B.transpose(1, 2))
        .view(N, K, K, D, D)
        .permute(0, 3, 4, 1, 2)
        .reshape(N, D * D, K, K)
    )
    res2 = res2.numpy()
    print(res2.shape, np.sum(res2))

    # numpy version: split rows/cols exactly like Transform, then contract over
    # the cell indices (p, q), keeping channels (i, j) and offsets (a, b).
    A = batch1.numpy().reshape(N, D, P, K, Q, K)
    B = batch2.numpy().reshape(N, D, P, K, Q, K)
    res3 = np.einsum('nipaqb,njpaqb->nijab', A, B).reshape(N, D * D, K, K)
    print(res3.shape, np.sum(res3))

    # Element-wise comparison; tolerances sized for float32 accumulation.
    match = bool(np.allclose(res2, res3, rtol=1e-4, atol=1e-4))
    print('max abs diff:', float(np.abs(res2 - res3).max()), 'match:', match)
    return match