Dot product batch-wise

I have two matrices of dimension (6, 256). I would like to calculate the dot product row-wise, so that the resulting matrix has dimensions (6, 1). torch.dot does not support batch-wise calculation. Is there an efficient way to do this?
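For concreteness, here is a minimal sketch of what I mean (the tensor names are just placeholders); the per-row loop over torch.dot gives the right result, but it is exactly what I would like to avoid:

import torch

A = torch.randn(6, 256)
B = torch.randn(6, 256)

# Naive baseline: torch.dot only accepts 1-D tensors, so loop over the rows.
out = torch.stack([torch.dot(A[i], B[i]) for i in range(A.shape[0])]).unsqueeze(1)
print(out.shape)  # torch.Size([6, 1])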

5 Likes

Each row is a vector with 256 elements; what do you mean by dot product? Do you just want to multiply all of those elements together?

torch.bmm(A.view(6, 1, 256), B.view(6, 256, 1)) should do the trick!

http://pytorch.org/docs/0.2.0/torch.html#torch.bmm
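A minimal sketch of that, with the shapes from the original question (the tensor names are placeholders):

import torch

A = torch.randn(6, 256)
B = torch.randn(6, 256)

# Each row becomes a (1, 256) x (256, 1) matrix product, i.e. a (1, 1) result per row.
out = torch.bmm(A.view(6, 1, 256), B.view(6, 256, 1)).view(6, 1)
print(out.shape)  # torch.Size([6, 1])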

22 Likes

Yeah! That would do. There is no direct function then, right?

I don’t think so. Well, this is very efficient because there is no copying to massage the data :slight_smile:
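A small check of the no-copy claim (just a sketch, the variable names are mine): view() returns a tensor that shares the same underlying storage, so reshaping for bmm does not move any data.

import torch

A = torch.randn(6, 256)
A_view = A.view(6, 1, 256)

# Both tensors point at the same memory, so the reshape itself is free.
print(A.data_ptr() == A_view.data_ptr())  # True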

3 Likes

OK, let’s say, on macOS, CPU:

In [166]: a = th.Tensor(np.random.rand(100000, 10))

In [167]: %timeit th.sum(a*a, dim=1).sum()
2.8 ms ± 39.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [168]: %timeit th.bmm(a.view(-1,1,10), a.view(-1,10,1)).sum()
64.2 ms ± 666 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

bmm is more than 20 times slower, and furthermore bmm makes the backward pass much slower. What?

8 Likes

In addition, the results are similar on GPU.
For the GPU version, bmm invokes separate CUDA kernels for each matrix multiplication, which in this case means 100,000 kernel launches, so the kernel-launch overhead dominates the total computation time.
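For reference, a rough sketch of how one might repeat the comparison on GPU (assuming a CUDA device is available); torch.cuda.synchronize() is needed so that asynchronous kernel launches do not distort the timings:

import time
import torch

a = torch.rand(100000, 10).cuda()

def bench(fn, iters=100):
    torch.cuda.synchronize()   # finish any pending work before timing
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()   # wait for the timed kernels to complete
    return (time.time() - start) / iters

print("mul+sum:", bench(lambda: (a * a).sum(dim=1)))
print("bmm:    ", bench(lambda: torch.bmm(a.view(-1, 1, 10), a.view(-1, 10, 1))))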

1 Like

Yeah, it might not be specifically optimized for this case. Thanks for doing the benchmark.

I just checked: bmm is a single batched GEMM kernel call. It’s not doing 100,000 kernel launches.

1 Like

Yeah, you’re right, it uses a single batched kernel. The CPU version, though, is actually a loop over the batch dimension.

1 Like

If anyone comes across this post via a Google search, I suggest checking out the following GitHub issue:

Given two batches of vectors A, B, the fastest approach is to just compute (A*B).sum(-1).
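A quick sketch of that recommendation applied to the shapes from the original question (tensor names are mine), including a check that it matches the bmm result:

import torch

A = torch.randn(6, 256)
B = torch.randn(6, 256)

out_mul = (A * B).sum(-1, keepdim=True)                               # shape (6, 1)
out_bmm = torch.bmm(A.view(6, 1, 256), B.view(6, 256, 1)).view(6, 1)

print(torch.allclose(out_mul, out_bmm))  # True, up to floating-point tolerance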

7 Likes

I’ll add another method, using matmul() with transpose(). The methods below are ordered from fastest to slowest:

a = torch.rand(2, 4)

%timeit (a*a).sum(1)
# 4.26 µs ± 21.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit torch.matmul(a, a.t()).diag()
# 6.81 µs ± 365 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit torch.bmm(a.view(2, 1, 4), a.view(2, 4, 1)).view(2, 1)
# 16.2 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
2 Likes

In general,
torch.bmm(A.unsqueeze(dim=1), B.unsqueeze(dim=2)).squeeze()
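A short sketch of that general form (the batch size and feature dimension below are arbitrary placeholders):

import torch

A = torch.randn(32, 128)
B = torch.randn(32, 128)

# (32, 1, 128) @ (32, 128, 1) -> (32, 1, 1), then squeeze down to (32,)
out = torch.bmm(A.unsqueeze(dim=1), B.unsqueeze(dim=2)).squeeze()
print(out.shape)                             # torch.Size([32])
print(torch.allclose(out, (A * B).sum(-1)))  # True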