Multiply two 3D tensors along different dims

I have two Tensor objects, t1 of size (D, m, n) and t2 of size (D, n, n), and I want to perform something like NumPy's tensordot(t1, t2, axes=([0, 2], [0, 2])), that is, contract the two 3D tensors over axes 0 and 2 with 2D matrix multiplications.
Is it possible to do this in PyTorch?
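
For concreteness, the operation I want is out[i, j] = sum over d and k of t1[d, i, k] * t2[d, j, k]. In NumPy (with hypothetical example sizes D=4, m=5, n=3):

import numpy as np

D, m, n = 4, 5, 3  # hypothetical example sizes
t1 = np.random.rand(D, m, n)
t2 = np.random.rand(D, n, n)
# out[i, j] = sum over d and k of t1[d, i, k] * t2[d, j, k]
out = np.tensordot(t1, t2, axes=([0, 2], [0, 2]))
print(out.shape)  # (5, 3), i.e. (m, n)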

The simplest way I see is to use view to merge the two contracted dimensions into a single common dimension and then use a classical 2D mm. Here is an example with D=7 and n=3:

import torch

t1 = torch.rand(7, 5, 3)  # (D, m, n) with D=7, m=5, n=3
t2 = torch.rand(7, 2, 3)  # second dim of 2 here, but any size works
# put t1 and t2 into compatible 2D shapes:
t1 = t1.transpose(1, 2).contiguous().view(7 * 3, -1).transpose(0, 1)  # (5, 21)
t2 = t2.transpose(1, 2).contiguous().view(7 * 3, -1)                  # (21, 2)
result = torch.mm(t1, t2)  # (5, 2)
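
As a sanity check (assuming a PyTorch version that provides torch.einsum), the same contraction can be written directly and compared against the reshape trick:

import torch

t1 = torch.rand(7, 5, 3)
t2 = torch.rand(7, 2, 3)
# the reshape trick from above:
ref = torch.mm(t1.transpose(1, 2).contiguous().view(7 * 3, -1).transpose(0, 1),
               t2.transpose(1, 2).contiguous().view(7 * 3, -1))
# the same contraction spelled out: out[i, j] = sum over d, k of t1[d, i, k] * t2[d, j, k]
alt = torch.einsum('dik,djk->ij', t1, t2)
print(torch.allclose(ref, alt))  # expected: True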

The problem here is that t2 is of size (D, n, n), so in your example it would be of size (7, 3, 3). When I view it, how can I be sure to “flatten” the first and third dimensions, and not the first and second ones?

By default view preserves the order of elements in memory, so on a contiguous tensor it “merges” leading dimensions first: if you have a (D, n, n) tensor and use .view(D*n, n), the first two dimensions will be merged.
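
A minimal illustration of that behavior (with hypothetical D=2, n=3):

import torch

x = torch.arange(2 * 3 * 3).view(2, 3, 3)  # contiguous (D, n, n) with D=2, n=3
y = x.view(2 * 3, 3)
# rows follow memory order, so y[d*3 + i] == x[d, i]: dims 0 and 1 got merged
print(torch.equal(y[1], x[0, 1]))  # expected: True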

So is there a way to merge the first and third dimensions instead?

That’s why I used transposition :wink:
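
To make the trick explicit, a small sketch with the same hypothetical shapes as above: transposing dims 1 and 2 first means the subsequent view merges the original dims 0 and 2:

import torch

x = torch.arange(2 * 3 * 3).view(2, 3, 3)  # (D, n, n) with D=2, n=3
z = x.transpose(1, 2).contiguous().view(2 * 3, 3)
# row d*3 + k of z holds x[d, :, k], so dims 0 and 2 got merged
print(torch.equal(z[1], x[0, :, 1]))  # expected: True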

Aha! I didn’t notice! Thank you!

The answer above matches the behavior of np.tensordot.

For those looking to do slice-by-slice (batched) multiplication over 3D tensors, torch.matmul might be better.

A = torch.tensor([[[i for i in range(1,3)]] * 5] * 2)
>>>tensor([[[1, 2],
         [1, 2],
         [1, 2],
         [1, 2],
         [1, 2]],

        [[1, 2],
         [1, 2],
         [1, 2],
         [1, 2],
         [1, 2]]])
A.size()
>>>torch.Size([2, 5, 2])

B = torch.tensor([[[i for i in range(2,4)]] * 3] * 2)
>>>tensor([[[2, 3],
         [2, 3],
         [2, 3]],

        [[2, 3],
         [2, 3],
         [2, 3]]])
B.size()
>>>torch.Size([2, 3, 2])

import numpy as np

# np.tensordot result:
D = np.tensordot(A.numpy(), B.numpy(), axes=([0, 2], [0, 2]))
>>>array([[16, 16, 16],
       [16, 16, 16],
       [16, 16, 16],
       [16, 16, 16],
       [16, 16, 16]])

# The proposed solution above:
A.transpose(1,2).contiguous().view(2*2,-1).transpose(0,1)
>>>tensor([[1, 2, 1, 2],
        [1, 2, 1, 2],
        [1, 2, 1, 2],
        [1, 2, 1, 2],
        [1, 2, 1, 2]])

B.transpose(1,2).contiguous().view(2*2,-1)
>>>tensor([[2, 2, 2],
        [3, 3, 3],
        [2, 2, 2],
        [3, 3, 3]])

torch.mm(A.transpose(1,2).contiguous().view(2*2,-1).transpose(0,1),
         B.transpose(1,2).contiguous().view(2*2,-1))
>>>tensor([[16, 16, 16],
        [16, 16, 16],
        [16, 16, 16],
        [16, 16, 16],
        [16, 16, 16]])
# --> Equivalent to tensordot answer


# With matmul, we can get slice by slice:
torch.matmul(A,B.transpose(1,2))
>>>tensor([[[8, 8, 8],
         [8, 8, 8],
         [8, 8, 8],
         [8, 8, 8],
         [8, 8, 8]],

        [[8, 8, 8],
         [8, 8, 8],
         [8, 8, 8],
         [8, 8, 8],
         [8, 8, 8]]])
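
If you want the tensordot result from the batched form, summing over dimension 0 should recover it, since the tensordot contraction over axis 0 is just a sum of the per-slice products:

# sum the per-slice products over dim 0 to recover the tensordot answer:
torch.matmul(A, B.transpose(1,2)).sum(dim=0)
>>>tensor([[16, 16, 16],
        [16, 16, 16],
        [16, 16, 16],
        [16, 16, 16],
        [16, 16, 16]])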
