How can I compute a 3D tensor * 2D tensor multiplication?

In numpy, when I have a 3D tensor X with shape [A, B, C] and a 2D tensor Y with shape [C, D], np.dot(X, Y) gives a 3D tensor with shape [A, B, D].
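As a quick illustration of the numpy behaviour described above, the sketch below uses arbitrary sizes A=2, B=3, C=4, D=5 (assumptions for the example only). It checks that np.dot sums over the last axis of X and the first axis of Y.

```python
import numpy as np

# Illustrative sizes (assumed for this example only)
A, B, C, D = 2, 3, 4, 5
X = np.random.rand(A, B, C)
Y = np.random.rand(C, D)

# For an N-D X and a 2-D Y, np.dot contracts the last axis of X
# with the first axis of Y, leaving shape [A, B, D]
out = np.dot(X, Y)
print(out.shape)  # (2, 3, 5)

# The same contraction written as an explicit broadcast-and-sum over C
expected = (X[:, :, :, None] * Y[None, None, :, :]).sum(axis=2)
```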

In PyTorch, I can do this as below.

import torch
# Flatten [A, B, C] to [A*B, C], multiply by Y [C, D], then restore [A, B, D]
result = torch.mm(X.view(-1, C), Y)
result = result.view(-1, B, D)
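A self-contained version of the flatten-then-mm approach, cross-checked against np.dot, might look like this. The sizes are assumptions for illustration. Note that view() needs a memory layout compatible with the new shape (a freshly created, contiguous X is fine).

```python
import numpy as np
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed)
X = torch.randn(A, B, C)
Y = torch.randn(C, D)

# Merge A and B into one row dimension, do a single 2D matmul, then split back
result = torch.mm(X.view(-1, C), Y).view(A, B, D)
print(result.shape)  # torch.Size([2, 3, 5])

# Same data through numpy's np.dot for comparison
expected = np.dot(X.numpy(), Y.numpy())
```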

Is there a simpler way to handle this?

Do you mean an entrywise product? Do you want broadcasting like numpy?
Maybe you can try expand_as and unsqueeze:

Y.unsqueeze(0).expand_as(X) * X

Y.unsqueeze(0) has shape (1, C, D), so it cannot be expanded to the shape of X, which is (A, B, C).
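To see the failure concretely, here is a small sketch with assumed sizes where B != C. expand() can only stretch size-1 dimensions, so the trailing dims (C, D) of Y.unsqueeze(0) would already have to equal (B, C).

```python
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed), with B != C
X = torch.randn(A, B, C)
Y = torch.randn(C, D)

# Y.unsqueeze(0) is (1, C, D); only the leading size-1 dim can be expanded,
# so expanding to (A, B, C) fails because C != B at dimension 1
try:
    Y.unsqueeze(0).expand_as(X)
    expanded_ok = True
except RuntimeError as err:
    expanded_ok = False
    print("expand_as failed:", err)
```

Even when the shapes happen to line up, `*` is an elementwise product, not a matrix product.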

What I want to do is broadcast 2D matrix multiplication, as below:

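import torch

def matmul(X, Y):
    # Multiply each [B, C] slice of X by Y, giving one [B, D] result per slice
    results = []
    for i in range(X.size(0)):
        result = torch.mm(X[i], Y)
        results.append(result)
    # torch.stack (not torch.cat) keeps the output shape [A, B, D]
    return torch.stack(results)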
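The choice between torch.cat and torch.stack at the end of a per-slice loop decides the output shape. A short sketch with assumed sizes:

```python
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed)
X = torch.randn(A, B, C)
Y = torch.randn(C, D)

# One [B, D] product per slice along the first dimension of X
per_slice = [torch.mm(X[i], Y) for i in range(X.size(0))]

# torch.cat joins along an existing dim, giving [A*B, D]
print(torch.cat(per_slice).shape)    # torch.Size([6, 5])
# torch.stack adds a new leading dim, giving [A, B, D]
print(torch.stack(per_slice).shape)  # torch.Size([2, 3, 5])
```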

@yunjey At the moment there is no simpler way. We plan to add broadcasting semantics to PyTorch.
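In later PyTorch releases, torch.matmul (and the `@` operator) broadcasts in exactly this way: the leading dims of X are treated as batch dims and Y is broadcast across them. A minimal sketch, with assumed sizes:

```python
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed)
X = torch.randn(A, B, C)
Y = torch.randn(C, D)

# No manual view/expand needed: matmul broadcasts the 2D Y over X's batch dim
out = torch.matmul(X, Y)
print(out.shape)  # torch.Size([2, 3, 5])

# Same numbers as the flatten-then-mm approach
ref = torch.mm(X.reshape(-1, C), Y).reshape(A, B, D)
```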

Ok, I got it. Thanks.

@chenyuntc, what you suggest would work, but it’s an elementwise multiplication. @yunjey, for the dot product, PyTorch seems to support only 2D tensors. So yes, for the moment you have to merge A and B into a single dimension (for instance using view), or you can use resize_ for slightly simpler code:

result = torch.mm(X.resize_(A*B,C), Y).resize_(A,B,D)

My mistake, I was careless.

I created operators for PyTorch with broadcasting; you can grab them from the forum post "Tip: using keras compatible tensor dot product and broadcasting ops".

Also, you should be able to do that with batched matrix multiply:

result = torch.bmm(X, Y.unsqueeze(0).expand(X.size(0), *Y.size()))
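A runnable version of the bmm approach, with assumed sizes. expand() gives Y a batch dim of size A without copying memory (the new dim has stride 0), so bmm sees A pairs of ([B, C], [C, D]) matrices.

```python
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed)
X = torch.randn(A, B, C)
Y = torch.randn(C, D)

# Add a leading batch dim and expand it to A without allocating new memory
Y_batched = Y.unsqueeze(0).expand(X.size(0), *Y.size())
print(Y_batched.shape, Y_batched.stride())  # torch.Size([2, 4, 5]) (0, 5, 1)

# Batched matmul: [A, B, C] x [A, C, D] -> [A, B, D]
result = torch.bmm(X, Y_batched)
print(result.shape)  # torch.Size([2, 3, 5])
```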

We’ll be adding broadcasting very soon.

@alexis-jacq don’t use resize_ on tensors that you want to use later! The contents of the tensor are unspecified after you call it.
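A safer alternative to resize_ in later PyTorch versions is reshape(). It returns a new tensor (a view when possible, otherwise a copy) and never modifies X in place. A minimal sketch, with assumed sizes:

```python
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed)
X = torch.randn(A, B, C)
Y = torch.randn(C, D)
X_before = X.clone()  # kept only to check that X is left untouched

# reshape() does not mutate X, unlike resize_()
result = torch.mm(X.reshape(A * B, C), Y).reshape(A, B, D)
print(result.shape)  # torch.Size([2, 3, 5])
```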

@yunjey, a numpy-compatible dot was recently published on the forums by @jphoward: "Tip: using keras compatible tensor dot product and broadcasting ops".
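The sketch below is not the code from that forum post. It shows a different built-in from later PyTorch versions, torch.tensordot, which mirrors np.tensordot. With dims=1 it contracts the last dim of X with the first dim of Y, the same contraction np.dot performs here. The sizes are assumptions.

```python
import numpy as np
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed)
X = torch.randn(A, B, C)
Y = torch.randn(C, D)

# dims=1: contract the last dim of X with the first dim of Y
out = torch.tensordot(X, Y, dims=1)
print(out.shape)  # torch.Size([2, 3, 5])

# numpy's tensordot with the same contraction, for comparison
expected = np.tensordot(X.numpy(), Y.numpy(), axes=1)
```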

For future viewers: there is a function in torch for this, einsum:
output = torch.einsum("abc,cd->abd", (X, Y))
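A self-contained einsum example with assumed sizes. In "abc,cd->abd", the index c appears in both inputs but not in the output, so it is summed over; a, b and d are kept. Operands can be passed as separate arguments or, as in the post above, as a single tuple.

```python
import torch

A, B, C, D = 2, 3, 4, 5  # illustrative sizes (assumed)
X = torch.randn(A, B, C)
Y = torch.randn(C, D)

# Sum over the shared index c, keep a, b, d
out = torch.einsum("abc,cd->abd", X, Y)
print(out.shape)  # torch.Size([2, 3, 5])
```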