# Multiply two 3D tensors along different dims

I have two tensors, t1 of size (D, m, n) and t2 of size (D, n, n), and I want to perform something like NumPy's `tensordot(t1, t2, axes=([0, 2], [0, 2]))`, i.e. 2D matrix multiplications contracting over axes 0 and 2 of the 3D tensors.
Is it possible to do this in PyTorch?

The simplest way I see is to use `view` to merge the common dimensions into one single dimension and then use a classical 2D `mm`. This is an example with D=7 and n=3:

```python
import torch

t1 = torch.rand(7, 5, 3)  # (D, m, n)
t2 = torch.rand(7, 2, 3)  # (D, p, n)

# Put t1 and t2 into compatible 2D shapes by moving the contracted
# axes (0 and 2) next to each other and flattening them:
t1 = t1.transpose(1, 2).contiguous().view(7 * 3, -1).transpose(0, 1)  # (5, 21)
t2 = t2.transpose(1, 2).contiguous().view(7 * 3, -1)                  # (21, 2)
result = torch.mm(t1, t2)                                             # (5, 2)
```
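As a sanity check (my addition, not part of the original answer), the merged-dimension result can be compared against NumPy's `tensordot`:

```python
import numpy as np
import torch

t1 = torch.rand(7, 5, 3)
t2 = torch.rand(7, 2, 3)

# Merge the contracted axes (0 and 2) as in the answer above
a = t1.transpose(1, 2).contiguous().view(7 * 3, -1).transpose(0, 1)  # (5, 21)
b = t2.transpose(1, 2).contiguous().view(7 * 3, -1)                  # (21, 2)
result = torch.mm(a, b)                                              # (5, 2)

# The same contraction expressed with NumPy
expected = np.tensordot(t1.numpy(), t2.numpy(), axes=([0, 2], [0, 2]))
print(np.allclose(result.numpy(), expected))  # True
```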

The problem here is that t2 is of size (D, n, n), so in your example it would be of size (7, 3, 3). When I `view` it, how can I be sure to "flatten" the first and third dimensions, and not the first and second ones?

`view` keeps the tensor's memory order, so it always merges the leading dimensions first: if you have a (D, n, n) tensor and call `.view(D*n, n)`, the first two dimensions are merged.
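A minimal sketch of this behavior, with D=2 and n=3:

```python
import torch

# A (D, n, n) tensor with D=2, n=3
x = torch.arange(18).view(2, 3, 3)

# view(D*n, n) only reinterprets the contiguous memory, so the first
# two dimensions are merged: row k of y is x[k // 3, k % 3]
y = x.view(2 * 3, 3)
print(torch.equal(y[4], x[1, 1]))  # True: row 4 is block 1, row 1
```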

So is there a way to merge the first and third dimensions instead?

That's why I used the transposition.

Aha! I didn't notice, thank you!

The answer above matches the behavior of `np.tensordot`.

For those looking to do slice-by-slice multiplication over 3D tensors, `torch.matmul` might be better:

```python
import numpy as np
import torch

A = torch.tensor([[[i for i in range(1, 3)]] * 5] * 2)
# tensor([[[1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2]],
#
#         [[1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2],
#          [1, 2]]])
A.size()
# torch.Size([2, 5, 2])

B = torch.tensor([[[i for i in range(2, 4)]] * 3] * 2)
# tensor([[[2, 3],
#          [2, 3],
#          [2, 3]],
#
#         [[2, 3],
#          [2, 3],
#          [2, 3]]])
B.size()
# torch.Size([2, 3, 2])

# np.tensordot result:
D = np.tensordot(A.numpy(), B.numpy(), axes=([0, 2], [0, 2]))
# array([[16, 16, 16],
#        [16, 16, 16],
#        [16, 16, 16],
#        [16, 16, 16],
#        [16, 16, 16]])

# The proposed solution above:
A.transpose(1, 2).contiguous().view(2 * 2, -1).transpose(0, 1)
# tensor([[1, 2, 1, 2],
#         [1, 2, 1, 2],
#         [1, 2, 1, 2],
#         [1, 2, 1, 2],
#         [1, 2, 1, 2]])

B.transpose(1, 2).contiguous().view(2 * 2, -1)
# tensor([[2, 2, 2],
#         [3, 3, 3],
#         [2, 2, 2],
#         [3, 3, 3]])

torch.mm(A.transpose(1, 2).contiguous().view(2 * 2, -1).transpose(0, 1),
         B.transpose(1, 2).contiguous().view(2 * 2, -1))
# tensor([[16, 16, 16],
#         [16, 16, 16],
#         [16, 16, 16],
#         [16, 16, 16],
#         [16, 16, 16]])
# --> Equivalent to the tensordot answer

# With matmul, we can multiply slice by slice:
torch.matmul(A, B.transpose(1, 2))
# tensor([[[8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8]],
#
#         [[8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8],
#          [8, 8, 8]]])
```
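For reference, assuming a PyTorch version recent enough to ship `torch.einsum` (and `torch.tensordot`), the same contraction can be written in one line, without the manual transpose/view gymnastics:

```python
import torch

A = torch.tensor([[[1, 2]] * 5] * 2).float()  # (2, 5, 2)
B = torch.tensor([[[2, 3]] * 3] * 2).float()  # (2, 3, 2)

# Contract over dims 0 and 2 of both tensors, matching
# np.tensordot(A, B, axes=([0, 2], [0, 2]))
D = torch.einsum('dik,djk->ij', A, B)  # (5, 3), every entry is 16
print(D)
```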