(More) Batch matrix operations

Will there be more batch matrix operations, such as batch versions of:

  • tril
  • triu
  • diag
  • trace

I can think of hacky ways to implement them by creating batch mask matrices, but that doesn’t seem efficient.
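For reference, here is a minimal sketch of the mask-based workaround mentioned above, using batch tril as an example (sizes are illustrative):

```python
import torch

N, R = 4, 3
x = torch.rand(N, R, R)

# Build a single (R, R) lower-triangular mask and broadcast it over the batch.
mask = torch.tril(torch.ones(R, R))
batch_tril = x * mask

# Check against a naive per-matrix loop.
expected = torch.stack([torch.tril(m) for m in x])
assert torch.equal(batch_tril, expected)
```

This works, but it materializes a full mask and does a full elementwise multiply per call, which is the inefficiency being raised here.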


At this time we don't have a better plan for these than a naive for-loop.
What are the use cases where you need them? Can you elaborate?

I’m trying to implement the Normalized Advantage Function (NAF). In this case, you need to compute x^T P x in batch mode.

There’s already an implementation available here, but it requires creating a mask and multiplying the resulting matrix by that. It works for now, but I was hoping to play around with variations of NAF, and these methods would help me.
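As a side note, the batched quadratic form x^T P x itself can be computed without masks via torch.bmm; a minimal sketch (batch size B and dimension n are illustrative):

```python
import torch

B, n = 8, 5
x = torch.rand(B, n, 1)   # batch of column vectors
P = torch.rand(B, n, n)   # batch of matrices

# (B, 1, n) @ (B, n, n) @ (B, n, 1) -> (B, 1, 1), then squeeze to (B,)
quad = torch.bmm(x.transpose(1, 2), torch.bmm(P, x)).squeeze()

# Check against a per-sample loop.
expected = torch.stack([x[i].t() @ P[i] @ x[i] for i in range(B)]).squeeze()
assert torch.allclose(quad, expected)
```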

I’m not sure about your usage of x^T P x, but the Bilinear layer might be helpful.

Is there an example for the usage of Bilinear Layer?

I wonder, for example, how could one use it in MNIST example.
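For what it’s worth, a minimal sketch of nn.Bilinear: for each of its out_features outputs it computes x1^T A_k x2 + b_k, so with out_features=1 it amounts to a learned batched quadratic form (all sizes here are illustrative):

```python
import torch
import torch.nn as nn

B, n = 8, 5
bilinear = nn.Bilinear(in1_features=n, in2_features=n, out_features=1)

x = torch.rand(B, n)
y = bilinear(x, x)        # shape (B, 1): one x^T A x + b value per sample

# Equivalent explicit computation using the layer's own parameters;
# bilinear.weight has shape (out_features, in1_features, in2_features).
A = bilinear.weight[0]    # (n, n)
b = bilinear.bias
expected = (x @ A * x).sum(dim=1, keepdim=True) + b
assert torch.allclose(y, expected, atol=1e-6)
```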

I think there is no way to do it in 0.4.0.
Looking at the docs for the latest version (0.5), I found support for batch diagonal via einsum('bii->bi', (batch_matrix,)).
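Concretely, the einsum route extracts the diagonal of each square matrix in the batch ('b' is the batch index, the repeated 'ii' selects the diagonal). A short sketch with the varargs calling convention of newer versions:

```python
import torch

N, n = 2, 4
x = torch.rand(N, n, n)   # square matrices only for this einsum spec

diags = torch.einsum('bii->bi', x)   # shape (N, n)

# Check against a per-matrix loop.
expected = torch.stack([x[i].diag() for i in range(N)])
assert torch.equal(diags, expected)
```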

Just as a note regarding batch diag. Let’s assume that we want to find the diagonals of N matrices of the same size (R, C), stacked in a variable x of size (N, R, C).

N = 2
R = 5
C = 3
x = torch.rand(N, R, C)

One way to get the diagonals is as follows:

x[[...] + [torch.arange(min(x.size()[-2:]), dtype=torch.long)] * 2]

However, this is very inefficient. In PyTorch versions strictly greater than 0.4.0, the equivalent function is:

torch.diagonal(x, dim1=-2, dim2=-1)

A temporary equivalent for PyTorch 0.4.0 that only creates a view of the diagonals:

torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
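As a sanity check, the indexing trick, the NumPy-view workaround, and torch.diagonal (in versions where it is available) all return the same (N, min(R, C)) diagonals:

```python
import torch

N, R, C = 2, 5, 3
x = torch.rand(N, R, C)

idx = torch.arange(min(R, C), dtype=torch.long)
via_indexing = x[..., idx, idx]
via_numpy = torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
via_diagonal = torch.diagonal(x, dim1=-2, dim2=-1)

assert torch.equal(via_indexing, via_numpy)
assert torch.equal(via_indexing, via_diagonal)
```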

Here is a comparison for the running time:

%timeit x[[...] + [torch.arange(min(x.size()[-2:]), dtype=torch.long)] * 2]
# 23.8 µs ± 687 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
# 3.02 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

We also note that it is slightly more efficient to stack the matrices as (R, C, N) with axis1=0 and axis2=1:

x = torch.rand(R, C, N)
%timeit x[[torch.arange(min(x.size()[:2]), dtype=torch.long)] * 2 + [...]].t()
# 19.6 µs ± 662 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit torch.from_numpy(x.numpy().diagonal(axis1=0, axis2=1))
# 2.58 µs ± 73.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)