Support for tensordot

NumPy and Theano support tensordot. It would be nice if PyTorch also had this feature.


If NumPy supports it, PyTorch supports it too, by going through .numpy():

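```python
import torch
import numpy as np

M = torch.randn(3, 3, 4)
v = torch.randn(2, 2, 3)

# Contract axis 2 of v with axis 0 of M (both size 3); result has shape (2, 2, 3, 4)
out = torch.from_numpy(np.tensordot(v.numpy(), M.numpy(), axes=[[2], [0]]))
```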

Not as concise, but it’s there.

Not exactly true: if you want to backprop through it, the round trip through NumPy breaks the autograd graph…
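A quick way to see the problem, assuming a recent PyTorch where `.numpy()` raises on tensors that require grad: you are forced to `detach()` first, and the tensor you get back has no history linking it to `v` or `M`.

```python
import numpy as np
import torch

M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)

# .numpy() refuses tensors that require grad, so the NumPy route must detach first
try:
    v.numpy()
    raised = False
except RuntimeError:
    raised = True

out = torch.from_numpy(
    np.tensordot(v.detach().numpy(), M.detach().numpy(), axes=[[2], [0]])
)

# The result is a fresh tensor with no grad_fn, so backward() cannot reach v or M
print(raised, out.requires_grad, out.grad_fn)
```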


Ah, yes, that’s true and that would be important under most circumstances.

So, is there any workaround for tensordot in PyTorch?

I adapted the NumPy tensordot implementation to PyTorch: pytorch tensordot
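The linked code isn't reproduced here. As a rough sketch of the same idea (assuming it follows NumPy's strategy, which I haven't checked line by line), tensordot reduces to permute, reshape and a single matrix multiply, all of which autograd can differentiate. `tensordot_sketch` is a hypothetical name, not a PyTorch API:

```python
import math

import torch


def tensordot_sketch(a, b, axes):
    """Hypothetical helper mirroring NumPy's tensordot approach:
    permute, flatten to 2-D, matmul, reshape back."""
    a_axes, b_axes = axes
    # Normalize negative axis indices
    a_axes = [ax % a.dim() for ax in a_axes]
    b_axes = [ax % b.dim() for ax in b_axes]
    for i, j in zip(a_axes, b_axes):
        if a.shape[i] != b.shape[j]:
            raise ValueError("shape mismatch for contracted axes")
    # Free (uncontracted) axes keep their original order
    a_free = [i for i in range(a.dim()) if i not in a_axes]
    b_free = [i for i in range(b.dim()) if i not in b_axes]
    # Move contracted axes to the end of a and to the front of b
    a_perm = a.permute(*(a_free + a_axes))
    b_perm = b.permute(*(b_axes + b_free))
    # Collapse the contracted axes into one dimension of size k
    k = math.prod(a.shape[i] for i in a_axes)
    a_mat = a_perm.reshape(-1, k)
    b_mat = b_perm.reshape(k, -1)
    out_shape = [a.shape[i] for i in a_free] + [b.shape[j] for j in b_free]
    return (a_mat @ b_mat).reshape(out_shape)


M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)
out = tensordot_sketch(v, M, axes=([2], [0]))  # shape (2, 2, 3, 4)
out.sum().backward()  # gradients flow back to v and M
```

Because every step is an ordinary differentiable tensor op, no custom backward is needed.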

You could use einsum.
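For the example from the start of the thread, the einsum version could look like this. The subscript string is my own choice; `torch.einsum` is differentiable, so the graph stays intact. Later PyTorch releases also provide `torch.tensordot` directly, used here only as a cross-check:

```python
import torch

M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)

# Contract the last axis of v (k) with the first axis of M (k),
# equivalent to np.tensordot(v, M, axes=[[2], [0]])
out = torch.einsum("ijk,klm->ijlm", v, M)

# Backprop works because einsum is an autograd-aware op
out.sum().backward()

# Cross-check against the built-in tensordot
ref = torch.tensordot(v, M, dims=([2], [0]))
print(out.shape, torch.allclose(out, ref, atol=1e-5))
```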