# Multiply two 3D tensors along different dims

I have two tensors, t1 of size (D, m, n) and t2 of size (D, n, n), and I want to perform something like NumPy's `tensordot(t1, t2, axes=([0, 2], [0, 2]))`, i.e. 2D matrix multiplications contracting over axes 0 and 2 of the 3D tensors.
Is it possible to do this in PyTorch?

The simplest way I see is to use `view` to merge the common dimensions into one single dimension and then use a classical 2D `mm`. This is an example with D=7 and n=3:

```python
import torch

t1 = torch.rand(7, 5, 3)  # (D, m, n)
t2 = torch.rand(7, 2, 3)  # (D, p, n)

# Put t1 and t2 into compatible 2D shapes by moving the contracted
# axes (0 and 2) next to each other and flattening them:
t1 = t1.transpose(1, 2).contiguous().view(7 * 3, -1).transpose(0, 1)  # (5, 21)
t2 = t2.transpose(1, 2).contiguous().view(7 * 3, -1)                  # (21, 2)
result = torch.mm(t1, t2)                                             # (5, 2)
```
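As a sanity check (my addition, not part of the original answer), the merged-dimension result can be compared against NumPy's `tensordot`:

```python
import numpy as np
import torch

t1 = torch.rand(7, 5, 3)
t2 = torch.rand(7, 2, 3)

# Merge the contracted axes (0 and 2) as in the answer above
a = t1.transpose(1, 2).contiguous().view(7 * 3, -1).transpose(0, 1)  # (5, 21)
b = t2.transpose(1, 2).contiguous().view(7 * 3, -1)                  # (21, 2)
result = torch.mm(a, b)                                              # (5, 2)

# The same contraction expressed with NumPy
expected = np.tensordot(t1.numpy(), t2.numpy(), axes=([0, 2], [0, 2]))
print(np.allclose(result.numpy(), expected))  # True
```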

The problem here is that t2 is of size (D, n, n), so in your example it would be of size (7, 3, 3). When I `view` it, how can I be sure to "flatten" the first and third dimensions, and not the first and second ones?

`view` keeps the tensor's memory order, so it always merges the leading dimensions first: if you have a (D, n, n) tensor and call `.view(D*n, n)`, the first two dimensions are merged.
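A minimal sketch of this behavior, with D=2 and n=3:

```python
import torch

# A (D, n, n) tensor with D=2, n=3
x = torch.arange(18).view(2, 3, 3)

# view(D*n, n) only reinterprets the contiguous memory, so the first
# two dimensions are merged: row k of y is x[k // 3, k % 3]
y = x.view(2 * 3, 3)
print(torch.equal(y[4], x[1, 1]))  # True: row 4 is block 1, row 1
```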

So is there a way to merge the first and third dimensions instead?

That's why I used the transposition.

Aha! I didn't notice, thank you!

The answer above matches the behavior of `np.tensordot`.

For those looking to do slice-by-slice multiplication over 3D tensors, `torch.matmul` might be better:

```python
import numpy as np
import torch

A = torch.tensor([[[i for i in range(1, 3)]] * 5] * 2)
# tensor([[[1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2]],
#
#         [[1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2]]])
A.size()
# torch.Size([2, 5, 2])

B = torch.tensor([[[i for i in range(2, 4)]] * 3] * 2)
# tensor([[[2, 3],
#          [2, 3],
#          [2, 3]],
#
#         [[2, 3],
#          [2, 3],
#          [2, 3]]])
B.size()
# torch.Size([2, 3, 2])

# np.tensordot result:
D = np.tensordot(A.numpy(), B.numpy(), axes=([0, 2], [0, 2]))
# array([[16, 16, 16],
#        [16, 16, 16],
#        [16, 16, 16],
#        [16, 16, 16],
#        [16, 16, 16]])

# The proposed solution above:
A.transpose(1, 2).contiguous().view(2 * 2, -1).transpose(0, 1)
# tensor([[1, 2, 1, 2],
#         [1, 2, 1, 2],
#         [1, 2, 1, 2],
#         [1, 2, 1, 2],
#         [1, 2, 1, 2]])

B.transpose(1, 2).contiguous().view(2 * 2, -1)
# tensor([[2, 2, 2],
#         [3, 3, 3],
#         [2, 2, 2],
#         [3, 3, 3]])

torch.mm(A.transpose(1, 2).contiguous().view(2 * 2, -1).transpose(0, 1),
         B.transpose(1, 2).contiguous().view(2 * 2, -1))
# tensor([[16, 16, 16],
#         [16, 16, 16],
#         [16, 16, 16],
#         [16, 16, 16],
#         [16, 16, 16]])
# --> Equivalent to the tensordot answer

# With matmul, we can multiply slice by slice:
torch.matmul(A, B.transpose(1, 2))
# tensor([[[8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8]],
#
#         [[8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8]]])
```
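For reference, assuming a PyTorch version recent enough to ship `torch.einsum` (and `torch.tensordot`), the same contraction can be written in one line, without the manual transpose/view gymnastics:

```python
import torch

A = torch.tensor([[[1, 2]] * 5] * 2).float()  # (2, 5, 2)
B = torch.tensor([[[2, 3]] * 3] * 2).float()  # (2, 3, 2)

# Contract over dims 0 and 2 of both tensors, matching
# np.tensordot(A, B, axes=([0, 2], [0, 2]))
D = torch.einsum('dik,djk->ij', A, B)  # (5, 3), every entry is 16
print(D)
```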