What can I use instead of Torch's 'MixtureTable'?

I have two kinds of tensors:

import torch

a = torch.LongTensor([1, 2, 3])
b = [torch.LongTensor([1, 1]), torch.LongTensor([2, 2]), torch.LongTensor([3, 3])]

Then I want to do an operation like nn.MixtureTable in Torch:

c = something(a, b)
print(c)

The result should look like this:

LongTensor - size: 1x2
c = [14, 14]
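As I understand nn.MixtureTable with a 1-D gater, it computes the gate-weighted sum of the experts, c = Σ_i a[i] · b[i], so here c = 1·[1, 1] + 2·[2, 2] + 3·[3, 3] = [14, 14]. A naive loop that spells out this definition, meant only as a reference for checking vectorized versions:

```python
import torch

a = torch.LongTensor([1, 2, 3])
b = [torch.LongTensor([1, 1]), torch.LongTensor([2, 2]), torch.LongTensor([3, 3])]

# Weighted sum of the experts: each gate value a[i] scales the tensor b[i].
c = torch.zeros_like(b[0])
for weight, expert in zip(a, b):
    c += weight * expert

print(c.tolist())  # [14, 14]
```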

Is there an operation that does this, or do I have to implement it myself?

Try torch.stack(b, dim=1).mul(a).sum(dim=1). It should work with both Tensors and Variables.
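The same one-liner, split up with the intermediate shapes annotated. This is a sketch that assumes a PyTorch version where mul broadcasts (0.2 or later):

```python
import torch

a = torch.LongTensor([1, 2, 3])
b = [torch.LongTensor([1, 1]), torch.LongTensor([2, 2]), torch.LongTensor([3, 3])]

stacked = torch.stack(b, dim=1)   # shape (2, 3): column i holds b[i]
weighted = stacked.mul(a)         # a, shape (3,), broadcasts across the rows
c = weighted.sum(dim=1)           # sum over the 3 experts -> shape (2,)

print(stacked.tolist())  # [[1, 2, 3], [1, 2, 3]]
print(c.tolist())        # [14, 14]
```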


@lantiga
Thank you for your reply.

One issue is that two tensors of different sizes cannot be multiplied, but I get your idea.
Here is an example for rank-2 tensors:

import torch

a = torch.LongTensor([[1,2,3],[4,5,6]])
b = [torch.LongTensor([[1,1],[1,1]]), torch.LongTensor([[2,2],[4,4]]), torch.LongTensor([[3,3],[9,9]])]
b = torch.stack(b, dim=2)  # shape (2, 2, 3)
a = a.view(a.size(0),1,a.size(1)).expand_as(b)  # (2, 3) -> (2, 1, 3) -> (2, 2, 3)

c = b.mul(a).sum(dim=2).view(b.size(0), b.size(1))  # shape (2, 2)
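To check what this computes: row n of a gates the n-th rows of the experts, so c[n] = Σ_k a[n, k] · b_k[n]. For the data above that gives [[14, 14], [78, 78]] (for the second row, 4·1 + 5·4 + 6·9 = 78). A sketch comparing the vectorized code against an explicit loop (the names experts and ref are only for illustration):

```python
import torch

a = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
experts = [torch.LongTensor([[1, 1], [1, 1]]),
           torch.LongTensor([[2, 2], [4, 4]]),
           torch.LongTensor([[3, 3], [9, 9]])]

# Vectorized version from the post above.
b = torch.stack(experts, dim=2)                          # (2, 2, 3)
a_exp = a.view(a.size(0), 1, a.size(1)).expand_as(b)     # (2, 2, 3)
c = b.mul(a_exp).sum(dim=2).view(b.size(0), b.size(1))   # (2, 2)

# Reference: weighted sum of the experts, computed row by row.
ref = torch.zeros_like(experts[0])
for n in range(a.size(0)):
    for k, expert in enumerate(experts):
        ref[n] += a[n, k] * expert[n]

print(c.tolist())  # [[14, 14], [78, 78]]
```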

If anyone has a better idea, please let me know.
Thanks.

Equivalent, but a bit more concise:

import torch

a = torch.LongTensor([[1,2,3],[4,5,6]])
b = [torch.LongTensor([[1,1],[1,1]]),
     torch.LongTensor([[2,2],[4,4]]),
     torch.LongTensor([[3,3],[9,9]])]

b = torch.stack(b, dim=2)

# relies on broadcasting (PyTorch 0.2+); older versions need expand_as
c = b.mul(a.unsqueeze(1)).sum(dim=2).squeeze()
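One caveat with the trailing squeeze(): in PyTorch 0.2 and later, sum(dim=2) already drops the summed dimension because keepdim defaults to False, so squeeze() is not needed there. It also removes any other size-1 dimension, including the batch dimension when a has a single row. A sketch illustrating this with a one-row gater (assuming a PyTorch version with broadcasting):

```python
import torch

a = torch.LongTensor([[1, 2, 3]])                      # one gater row: (1, 3)
b = torch.stack([torch.LongTensor([[1, 1]]),
                 torch.LongTensor([[2, 2]]),
                 torch.LongTensor([[3, 3]])], dim=2)   # (1, 2, 3)

summed = b.mul(a.unsqueeze(1)).sum(dim=2)   # keepdim=False -> (1, 2)
squeezed = summed.squeeze()                 # also drops the batch dim -> (2,)

print(summed.shape)    # torch.Size([1, 2])
print(squeezed.shape)  # torch.Size([2])
```

Keeping just .sum(dim=2) preserves the (batch, feature) layout regardless of batch size.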

Oh right! ‘squeeze’ and ‘unsqueeze’ are more suitable than ‘view’.
But there is still the problem that two tensors of different sizes cannot be multiplied element-wise, even when they have the same rank
(e.g. a (2x1x3 tensor) * b (2x2x3 tensor), where ‘*’ denotes element-wise multiplication).
This can be solved by using ‘expand’.
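expand / expand_as does not copy data: it returns a view in which the expanded dimension has stride 0, so the same gate values are reused for every expert row. In PyTorch 0.2 and later, element-wise ops broadcast automatically, so b * a.unsqueeze(1) gives the same result without an explicit expand_as. A sketch comparing the two (variable names are illustrative):

```python
import torch

a = torch.LongTensor([[1, 2, 3], [4, 5, 6]])                 # (2, 3)
b = torch.stack([torch.LongTensor([[1, 1], [1, 1]]),
                 torch.LongTensor([[2, 2], [4, 4]]),
                 torch.LongTensor([[3, 3], [9, 9]])], dim=2)  # (2, 2, 3)

a_expanded = a.unsqueeze(1).expand_as(b)   # (2, 2, 3) view of a's storage
print(a_expanded.stride(1))                # 0: dim 1 repeats the same values

explicit = b.mul(a_expanded)               # multiply with the expanded view
broadcast = b * a.unsqueeze(1)             # (2, 2, 3) * (2, 1, 3) broadcasts
print(torch.equal(explicit, broadcast))    # True
```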

So I suggest the following for anyone reading this discussion:

import torch

a = torch.LongTensor([[1,2,3],[4,5,6]])
b = [torch.LongTensor([[1,1],[1,1]]),
     torch.LongTensor([[2,2],[4,4]]),
     torch.LongTensor([[3,3],[9,9]])]

b = torch.stack(b, dim=2)  # shape (2, 2, 3)

# gate (2, 1, 3) expanded to (2, 2, 3) to match b, then summed over the experts
c = b.mul(a.unsqueeze(1).expand_as(b)).sum(dim=2).squeeze()
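For reuse, the same idea can be wrapped in a small function. mixture_table below is a hypothetical helper (not a PyTorch API), sketched under the assumptions that the gater has shape (*batch, K), each of the K experts has shape (*batch, *feature) with the same leading batch dimensions, and that a recent PyTorch is used (it relies on broadcasting and reshape). It covers both the 1-D case from the question and the batched case above:

```python
import torch


def mixture_table(gater, experts):
    """Hypothetical helper mimicking Lua Torch's nn.MixtureTable.

    gater:   tensor of shape (*batch, K)
    experts: list of K tensors, each of shape (*batch, *feature)
    returns: tensor of shape (*batch, *feature), the gate-weighted sum of experts
    """
    stacked = torch.stack(experts, dim=-1)          # (*batch, *feature, K)
    n_feature_dims = stacked.dim() - gater.dim()    # number of feature dims
    # Insert singleton feature dims so the gate broadcasts over every expert element.
    gate_shape = tuple(gater.shape[:-1]) + (1,) * n_feature_dims + (gater.shape[-1],)
    return (stacked * gater.reshape(gate_shape)).sum(dim=-1)


# 1-D gater, as in the original question.
a1 = torch.LongTensor([1, 2, 3])
b1 = [torch.LongTensor([1, 1]), torch.LongTensor([2, 2]), torch.LongTensor([3, 3])]
print(mixture_table(a1, b1).tolist())  # [14, 14]

# Batched gater (one row per sample), as in the rank-2 example.
a2 = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
b2 = [torch.LongTensor([[1, 1], [1, 1]]),
      torch.LongTensor([[2, 2], [4, 4]]),
      torch.LongTensor([[3, 3], [9, 9]])]
print(mixture_table(a2, b2).tolist())  # [[14, 14], [78, 78]]
```

Stacking on the last dimension keeps the expert index K aligned with the gater's last dimension, so no squeeze() is needed and a batch of size 1 keeps its shape.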

I really appreciate your helpful comments.
@lantiga

Hi @btjhjeon, indeed! My bad, I left out the expand_as.
