NumPy and Theano support tensordot. It would be nice if PyTorch also had this feature.
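For reference, tensordot contracts chosen axes of one tensor against matching axes of another. Below is a minimal sketch of the same operation built only from PyTorch ops (permute, reshape, matrix multiply). It assumes the contracted axes are passed as two equal-length lists with matching sizes. `tensordot_sketch` is a hypothetical helper, not a PyTorch API. Because it uses only ordinary tensor ops, autograd can track it.

```python
import torch

def tensordot_sketch(a, b, dims_a, dims_b):
    # Hypothetical helper: contract a's axes dims_a with b's axes dims_b,
    # mirroring numpy.tensordot(a, b, axes=[dims_a, dims_b]).
    free_a = [d for d in range(a.dim()) if d not in dims_a]
    free_b = [d for d in range(b.dim()) if d not in dims_b]
    # Move a's contracted axes to the end and b's contracted axes to the front.
    a_perm = a.permute(*free_a, *dims_a)
    b_perm = b.permute(*dims_b, *free_b)
    # k = total number of elements along the contracted axes.
    k = 1
    for d in dims_a:
        k *= a.shape[d]
    out_shape = [a.shape[d] for d in free_a] + [b.shape[d] for d in free_b]
    # Flatten to 2-D, multiply, then restore the free axes.
    result = a_perm.reshape(-1, k) @ b_perm.reshape(k, -1)
    return result.reshape(out_shape)

M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)
out = tensordot_sketch(v, M, [2], [0])
print(out.shape)  # torch.Size([2, 2, 3, 4])
out.sum().backward()  # gradients flow back to v and M
```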
If NumPy supports it, PyTorch supports it, through the use of numpy():
import torch
import numpy as np
M = torch.randn(3, 3, 4)
v = torch.randn(2, 2, 3)
out = torch.Tensor(np.tensordot(v.numpy(), M.numpy(), axes=[[2], [0]]))
Not as concise, but it’s there.
Not exactly true: if you want to backprop through it, converting to NumPy and back will break the autograd graph.
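A small sketch of why the graph breaks, assuming a current PyTorch release where inputs created with `requires_grad=True` refuse `.numpy()` until they are detached. The detached NumPy round trip produces a tensor with no `grad_fn`, so autograd cannot reach `v` or `M`.

```python
import numpy as np
import torch

M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)

# Calling .numpy() on a tensor that requires grad raises a RuntimeError,
# so the inputs have to be detached first.
try:
    v.numpy()
except RuntimeError as err:
    print("numpy() refused:", err)

out = torch.from_numpy(
    np.tensordot(v.detach().numpy(), M.detach().numpy(), axes=[[2], [0]])
)
# The result is a fresh tensor with no grad_fn: autograd has no record
# of how it was computed from v or M.
print(out.requires_grad, out.grad_fn)  # False None
```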
Ah, yes, that’s true and that would be important under most circumstances.
So, is there any workaround for tensordot in PyTorch?
You could use einsum.
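As a sketch of that suggestion, the tensordot call from earlier in the thread (`axes=[[2], [0]]`) can be written with `torch.einsum`. The index letters are an arbitrary choice. Unlike the NumPy round trip, this stays inside autograd.

```python
import torch

M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)

# 'k' labels both v's last axis and M's first axis and is absent from the
# output, so it is summed over: the same contraction as axes=[[2], [0]].
out = torch.einsum('ijk,klm->ijlm', v, M)
print(out.shape)  # torch.Size([2, 2, 3, 4])

# einsum is an ordinary differentiable op, so gradients reach both inputs.
out.sum().backward()
print(v.grad.shape, M.grad.shape)
```

Newer PyTorch releases also ship `torch.tensordot`, where the same contraction is `torch.tensordot(v, M, dims=([2], [0]))`.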