Example of how to live without einsum in PyTorch

The code below mostly speaks for itself: the idea is to fold the extra 3x3 indices into the batch dimension and replace the einsum with torch.bmm. I hope it is correct and useful for you.

import numpy as np
import torch

def SelfTest():
    N, D, P, Q = 10, 64, 5, 5
    
    batch1 = torch.randn(N, D, P * 3, Q * 3)
    batch2 = torch.randn(N, D, P * 3, Q * 3)
    
    def Transform(x):
        # (N, D, P*3, Q*3) -> (N*3*3, D, P*Q): the 3x3 offsets join the batch
        # dimension, the P*Q spatial positions become the contraction dimension.
        return x.view(N, D, P, 3, Q, 3).permute(0, 3, 5, 1, 2, 4).contiguous().view(N * 3 * 3, D, P * Q)
    
    A = Transform(batch1)
    B = Transform(batch2)
    # Batched matrix product over P*Q gives (N*3*3, D, D); then reorder so the
    # D x D block precedes the 3 x 3 offsets, matching the einsum output below.
    res2 = torch.bmm(A, B.permute(0, 2, 1)).view(N, 3, 3, D, D)
    res2 = res2.permute(0, 3, 4, 1, 2).reshape(N, D * D, 3, 3)
    print(res2.size(), np.sum(res2.numpy()))
    
    # numpy reference via einsum, using the same (3, 3) / (P, Q) grouping as Transform
    def NumpyTransform(x):
        return x.numpy().reshape(N, D, P, 3, Q, 3).transpose(0, 1, 3, 5, 2, 4).reshape(N, D, 3 * 3, P * Q)
    
    A = NumpyTransform(batch1)
    B = NumpyTransform(batch2)
    res3 = np.einsum('niab,njab->nija', A, B)
    res3 = np.reshape(res3, (N, D * D, 3, 3))
    print(res3.shape, np.sum(res3))
    assert np.allclose(res2.numpy(), res3, atol=1e-4)

torch.matmul accepts batched matrices with arbitrary leading batch dimensions, so you won't need the extra view that flattens them into a single batch axis.
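As a minimal sketch of that suggestion (the helper names MatmulVersion and Split are mine, not from the original post), the flattening to N*3*3 can be dropped and matmul left to broadcast over the leading (N, 3, 3) dimensions:

import torch

def MatmulVersion(batch1, batch2, N, D, P, Q):
    def Split(x):
        # (N, D, P*3, Q*3) -> (N, 3, 3, D, P*Q); the batch dims stay unflattened.
        return x.view(N, D, P, 3, Q, 3).permute(0, 3, 5, 1, 2, 4).contiguous().view(N, 3, 3, D, P * Q)
    
    A = Split(batch1)
    B = Split(batch2)
    # (N, 3, 3, D, P*Q) @ (N, 3, 3, P*Q, D) -> (N, 3, 3, D, D); matmul broadcasts
    # over the leading (N, 3, 3) dimensions.
    res = torch.matmul(A, B.transpose(-1, -2))
    return res.permute(0, 3, 4, 1, 2).reshape(N, D * D, 3, 3)

Calling MatmulVersion(batch1, batch2, N, D, P, Q) should reproduce res2 from SelfTest above, just without the explicit flatten/unflatten of the batch axes.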