Multiply each 2D matrix in a 3D tensor by another 3D tensor

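    import torch

    # tensor_a has shape (256, 4, 64) and tensor_b has shape (256, 64, 1)
    tensor_a = torch.randn(256, 4, 64)  # example inputs
    tensor_b = torch.randn(256, 64, 1)

    res = None
    for u in range(256):
        # (4, 64) @ (256, 64, 1) broadcasts to (256, 4, 1)
        temp = torch.matmul(tensor_a[u, :, :], tensor_b)
        # sum every entry into a single value of shape (1,)
        temp = temp.sum().unsqueeze(-1)
        if u == 0:
            res = temp
        else:
            res = torch.cat((res, temp), -1)

    res = res.unsqueeze(-1)  # final shape (256, 1)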

Can anyone show how to get the same result using torch functions instead of a loop?
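For reference, the shapes in the loop work out as follows: tensor_a[u] has shape (4, 64), and torch.matmul broadcasts it against all 256 matrices in tensor_b, so every iteration reduces a (256, 4, 1) product to a single number. A minimal sketch confirming this, assuming random float64 inputs (the data is illustrative, not from the original post), with torch.stack used in place of the repeated torch.cat:

```python
import torch

# Illustrative inputs with the shapes from the question (random data is an assumption)
tensor_a = torch.randn(256, 4, 64, dtype=torch.float64)
tensor_b = torch.randn(256, 64, 1, dtype=torch.float64)

# One loop iteration: (4, 64) @ (256, 64, 1) broadcasts over tensor_b's batch dim
step = torch.matmul(tensor_a[0], tensor_b)
print(step.shape)  # torch.Size([256, 4, 1])

# Same result as the loop: stack the 256 scalar sums, then add the trailing dim
res = torch.stack([torch.matmul(tensor_a[u], tensor_b).sum() for u in range(256)])
res = res.unsqueeze(-1)
print(res.shape)  # torch.Size([256, 1])
```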

You could broadcast tensor_a against the whole batch of tensor_b, which uses more memory (the intermediate product has shape (256, 256, 4, 1)) in exchange for a potential performance gain:

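    # (256, 1, 4, 64) @ (256, 64, 1) broadcasts to (256, 256, 4, 1)
    out = torch.matmul(tensor_a.unsqueeze(1), tensor_b)
    # sum over tensor_b's batch dim and the 4 rows -> (256, 1), same shape as res
    out = out.sum([1, 2])
    print((out - res).abs().max())
    # Output: tensor(2.5580e-13, dtype=torch.float64)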

Note that I used DoubleTensors to check the error, since numerical errors can accumulate quite a lot over sums this large.
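One further option, not part of the answer above: because the loop sums every entry of each product, the sums can be moved inside the matmul, since res[u] = sum_k (sum_i a[u, i, k]) * (sum_v b[v, k, 0]). A minimal sketch under that reading, assuming float64 inputs of the question's shapes (random data is illustrative), comparing it with the broadcast version:

```python
import torch

# Illustrative inputs with the question's shapes (random data is an assumption)
tensor_a = torch.randn(256, 4, 64, dtype=torch.float64)
tensor_b = torch.randn(256, 64, 1, dtype=torch.float64)

# Broadcast version from the answer: builds a (256, 256, 4, 1) intermediate
out = torch.matmul(tensor_a.unsqueeze(1), tensor_b).sum([1, 2])

# Sum-first version: sum the 4 rows of each tensor_a matrix -> (256, 64),
# sum tensor_b over its batch dim -> (64, 1), then one small matmul -> (256, 1)
fast = tensor_a.sum(dim=1) @ tensor_b.sum(dim=0)

print(fast.shape)  # torch.Size([256, 1])
print((out - fast).abs().max())
```

This only holds because the loop reduces the whole (256, 4, 1) product with .sum(); if only part of each product were summed, the broadcast version would still be needed.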
