(More) Batch matrix operations

Will there be more batch matrix operations, such as batch versions of:

• tril
• triu
• diag
• trace
?

I can think of hacky ways to implement them by creating batch mask matrices, but that doesn’t seem efficient.


At this time we don't have a better plan for these than a naive for-loop.
What are the use-cases where these are needed? Can you elaborate?

I’m trying to implement the Normalized Advantage Function (NAF). In this case, you need to compute $x^{\top}Px$ in batch mode.

There’s already an implementation available here, but it requires creating a mask and multiplying the resulting matrix by that. It works for now, but I was hoping to play around with variations of NAF, and these methods would help me.
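For concreteness, here is a minimal sketch of the batched quadratic form $x^{\top}Px$ computed with `torch.bmm` and no masks (the shapes `B` and `n` here are illustrative, not from the NAF code):

```
import torch

# Batched quadratic form x^T P x via torch.bmm, no mask matrices needed.
B, n = 4, 3
x = torch.rand(B, n)       # batch of B vectors
P = torch.rand(B, n, n)    # batch of B matrices
# (B, 1, n) @ (B, n, n) @ (B, n, 1) -> (B, 1, 1) -> (B,)
q = torch.bmm(x.unsqueeze(1), torch.bmm(P, x.unsqueeze(2))).view(B)
```

Each entry of `q` equals the scalar `x[i] @ P[i] @ x[i]`.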

I’m not sure about your usage of $x^{\top}Px$, but the Bilinear layer might be helpful.
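As a rough sketch, `nn.Bilinear` computes a learned bilinear form $x_1^{\top} A x_2 + b$ per batch element, so passing the same tensor twice gives a batched quadratic form (the sizes below are illustrative):

```
import torch
import torch.nn as nn

# nn.Bilinear(in1, in2, out) computes x1^T A x2 + b per batch element.
m = nn.Bilinear(in1_features=5, in2_features=5, out_features=1)
x = torch.rand(8, 5)   # batch of 8 vectors
y = m(x, x)            # batched x^T A x + b, shape (8, 1)
```

Note that `A` is a learned parameter here, not a matrix you supply, so this fits NAF only if the quadratic matrix is meant to be learned in this form.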

Is there an example for the usage of Bilinear Layer?

I wonder, for example, how one could use it in the MNIST example.

I think there is no way to do it in 0.4.0.
By looking at the doc of the latest version (0.5), I found support for batch diagonal and `einsum('bii->bi', (batch_matrix,))`.
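For example, using the current `torch.einsum` calling convention (which takes the operands directly rather than as a tuple):

```
import torch

x = torch.rand(2, 3, 3)          # batch of 2 square matrices
d = torch.einsum('bii->bi', x)   # batch diagonal, shape (2, 3)
```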

Just as a note regarding batch diag. Let’s assume that we want to find the diagonals of N matrices of the same sizes (R, C) stacked in a variable `x` of size (N, R, C).

```
import torch

N = 2
R = 5
C = 3
x = torch.rand(N, R, C)
```

One way to get the diagonals is as follows:

```
x[[...] + [torch.arange(min(x.size()[-2:]), dtype=torch.long)] * 2]
```

However, this is very inefficient. PyTorch versions strictly greater than 0.4.0 implement this directly:

```
torch.diagonal(x, dim1=-2, dim2=-1)
```

A temporary equivalent solution for PyTorch 0.4.0 that only makes a view of the diagonals:

```
torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
```
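As a quick sanity check that this really is a view and not a copy, a write into `x` is visible through the result (NumPy marks the diagonal view read-only, so PyTorch may emit a warning about non-writable arrays, but the memory is still shared):

```
import torch

x = torch.rand(2, 5, 3)
# View of the diagonals, shape (2, 3); shares memory with x
d = torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
x[0, 0, 0] = 42.0
# d[0, 0] now reads 42.0 because d is a view of x's storage
```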

Here is a comparison of the running times:

```
%timeit x[[...] + [torch.arange(min(x.size()[-2:]), dtype=torch.long)] * 2]
# 23.8 µs ± 687 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
# 3.02 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

We also note that it is slightly more efficient to stack the matrices as (R, C, N), using `axis1=0` and `axis2=1`:

```
x = torch.rand(R, C, N)
%timeit x[[torch.arange(min(x.size()[:2]), dtype=torch.long)] * 2 + [...]].t()
# 19.6 µs ± 662 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit torch.from_numpy(x.numpy().diagonal(axis1=0, axis2=1))
# 2.58 µs ± 73.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```