Support for tensordot

NumPy and Theano support tensordot. It would be nice if PyTorch also had this feature.

If NumPy supports it, PyTorch supports it too, through the use of .numpy():

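import torch
import numpy as np

M = torch.randn(3, 3, 4)
v = torch.randn(2, 2, 3)

# Contract the last axis of v (size 3) with the first axis of M (size 3),
# then wrap the NumPy result back into a tensor of shape (2, 2, 3, 4).
out = torch.from_numpy(np.tensordot(v.numpy(), M.numpy(), axes=[[2], [0]]))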

Not as concise, but it’s there.
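For the shapes in that snippet, axes=[[2], [0]] sums over the last axis of v and the first axis of M (both of size 3), so out has shape 2×2×3×4 and each entry is:

```latex
\mathrm{out}_{ijlm} = \sum_{k=1}^{3} v_{ijk} \, M_{klm}
```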

Not exactly true: if you want to backprop through it, the round trip through NumPy breaks the autograd graph…
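A small sketch of what "breaking the graph" looks like, assuming a recent PyTorch release (older Variable-based versions behaved differently): calling .numpy() on a tensor that requires grad is refused, and after detaching, the result has no grad_fn, so backward() can never reach v or M.

```python
import numpy as np
import torch

M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)

# Recent PyTorch refuses to convert a tensor that requires grad.
refused = False
try:
    v.numpy()
except RuntimeError as err:
    refused = True
    print("refused:", err)

# Detaching lets the NumPy call run, but the result is cut off from v and M:
# it is a brand-new tensor with no grad_fn linking it back to the inputs.
out = torch.from_numpy(
    np.tensordot(v.detach().numpy(), M.detach().numpy(), axes=[[2], [0]])
)
print(out.requires_grad, out.grad_fn)
```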


Ah, yes, that’s true and that would be important under most circumstances.

So, is there any workaround for tensordot in PyTorch?

I adapted NumPy's tensordot implementation to PyTorch: pytorch tensordot
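The linked code isn't reproduced here, but the general idea behind NumPy's approach can be sketched with differentiable PyTorch ops: move the free axes to the front and the contracted axes to the back, flatten both operands to 2-D, do one matrix product, and reshape back. my_tensordot is a hypothetical name for this sketch, not the linked implementation or a PyTorch API.

```python
import numpy as np
import torch


def my_tensordot(a, b, axes):
    """Hypothetical helper: NumPy-style tensordot built from permute/reshape/matmul,
    so autograd can track every step."""
    axes_a, axes_b = axes
    # Normalize negative axis indices, as NumPy does.
    axes_a = [ax % a.dim() for ax in axes_a]
    axes_b = [ax % b.dim() for ax in axes_b]
    if len(axes_a) != len(axes_b):
        raise ValueError("axes must have the same length for both tensors")
    for ax_a, ax_b in zip(axes_a, axes_b):
        if a.shape[ax_a] != b.shape[ax_b]:
            raise ValueError("shape mismatch along contracted axes")

    free_a = [i for i in range(a.dim()) if i not in axes_a]
    free_b = [i for i in range(b.dim()) if i not in axes_b]

    # Number of elements summed over for each output entry.
    n_sum = 1
    for ax in axes_a:
        n_sum *= a.shape[ax]

    # a: free axes first, contracted axes last, flattened to (rows, n_sum).
    # b: contracted axes first, free axes last, flattened to (n_sum, cols).
    a2 = a.permute(*(free_a + axes_a)).reshape(-1, n_sum)
    b2 = b.permute(*(axes_b + free_b)).reshape(n_sum, -1)

    # One matmul does all the summing; then restore the free axes.
    out_shape = [a.shape[i] for i in free_a] + [b.shape[i] for i in free_b]
    return (a2 @ b2).reshape(out_shape)


torch.manual_seed(0)
M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)

out = my_tensordot(v, M, axes=[[2], [0]])

# Compare against NumPy on detached copies.
ref = torch.from_numpy(
    np.tensordot(v.detach().numpy(), M.detach().numpy(), axes=[[2], [0]])
)
print(torch.allclose(out.detach(), ref, atol=1e-5))

# Unlike the NumPy round trip, gradients reach both inputs.
out.sum().backward()
print(v.grad.shape, M.grad.shape)
```

Going through a single 2-D matmul mirrors how NumPy's tensordot reduces the contraction to a dot product; reshape (rather than view) is used because permute usually leaves the tensor non-contiguous.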

You could use torch.einsum.
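A minimal sketch of the einsum route for the same contraction as the NumPy snippet, reusing its shapes. It also shows torch.tensordot, which later PyTorch releases added with a NumPy-like dims argument; check that your installed version has it before relying on it.

```python
import torch

torch.manual_seed(0)
M = torch.randn(3, 3, 4, requires_grad=True)
v = torch.randn(2, 2, 3, requires_grad=True)

# 'k' appears in both inputs but not in the output, so it is summed over:
# out[i, j, l, m] = sum_k v[i, j, k] * M[k, l, m]
out = torch.einsum('ijk,klm->ijlm', v, M)

# Same contraction via torch.tensordot (axis 2 of v against axis 0 of M).
out_td = torch.tensordot(v, M, dims=([2], [0]))

# Both stay inside the autograd graph, so gradients flow back to v and M.
out.sum().backward()
print(out.shape, v.grad.shape, M.grad.shape)
```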