Will there be more batch matrix operations, such as batch versions of the following?
- tril
- triu
- diag
- trace
I can think of hacky ways to implement them by creating batch mask matrices, but that doesn’t seem efficient.
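For example, the mask-based workaround I had in mind looks roughly like this (a minimal sketch for a batched tril, assuming x holds a batch of matrices of shape (N, R, C)):

import torch

N, R, C = 2, 5, 3
x = torch.rand(N, R, C)
mask = torch.tril(torch.ones(R, C))   # single lower-triangular mask
batch_tril = x * mask                 # broadcasts the mask over the batch dimension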
At this time we don't have better plans for these than a naive for-loop.
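For example, a batched trace would currently be something along these lines (just a rough sketch looping over the batch dimension):

import torch

def batch_trace(x):
    # x: (B, n, n) -> (B,); one torch.trace call per matrix in the batch
    return torch.stack([torch.trace(x[i]) for i in range(x.size(0))])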
What are the use-cases where you need these? Can you elaborate?
I’m trying to implement the Normalized Advantage Function (NAF). In this case, you need to compute $x^{\top} P x$ in batch mode.
There’s already an implementation available here, but it requires creating a mask and multiplying the resulting matrix by that. It works for now, but I was hoping to play around with variations of NAF, and these methods would help me.
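Concretely, the batched quadratic form I need is something like this (a small sketch with torch.bmm; I’m assuming x of shape (B, n) and P of shape (B, n, n)):

import torch

B, n = 32, 6
x = torch.rand(B, n)
P = torch.rand(B, n, n)
# x^T P x for every batch element: (B, 1, n) x (B, n, n) x (B, n, 1) -> (B,)
quad = torch.bmm(x.unsqueeze(1), torch.bmm(P, x.unsqueeze(2))).view(-1)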
I’m not sure about your usage of $x^{\top} P x$; the Bilinear layer might be helpful.
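For instance, a minimal sketch (not from the NAF code; the feature size is an assumption): nn.Bilinear with the same input on both sides gives a learned quadratic form per batch element.

import torch
import torch.nn as nn

n = 6
bilinear = nn.Bilinear(n, n, 1, bias=False)  # learns a single n x n matrix A
x = torch.rand(32, n)
out = bilinear(x, x).squeeze(1)              # x^T A x for each of the 32 inputs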
Is there an example of the usage of the Bilinear layer?
I wonder, for example, how one could use it in the MNIST example.
I think there is no way to do it in 0.4.0.
By looking at the docs of the latest version (0.5), I found support for batch diagonal and einsum('bii->bi', (batch_matrix,)).
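For reference, a small check of that einsum form against the batched diagonal (a sketch; it needs a version where both calls are available):

import torch

batch_matrix = torch.rand(4, 3, 3)
diags = torch.einsum('bii->bi', batch_matrix)          # (4, 3): one diagonal per matrix
same = torch.diagonal(batch_matrix, dim1=-2, dim2=-1)  # equivalent, also batched
print(torch.equal(diags, same))  # True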
Just as a note regarding batch diag: let’s assume that we want to find the diagonals of N matrices of the same size (R, C), stacked in a variable x of size (N, R, C).
N = 2
R = 5
C = 3
x = torch.rand(N, R, C)
One way to get the diagonals is as follows:
x[[...] + [torch.arange(min(x.size()[-2:]), dtype=torch.long)] * 2]
However, this is very inefficient. PyTorch versions strictly greater than 0.4.0 provide a built-in function for this:
torch.diagonal(x, dim1=-2, dim2=-1)
A temporary equivalent for PyTorch 0.4.0 that only creates a view of the diagonals:
torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
Here is a comparison for the running time:
%timeit x[[...] + [torch.arange(min(x.size()[-2:]), dtype=torch.long)] * 2]
# 23.8 µs ± 687 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit torch.from_numpy(x.numpy().diagonal(axis1=-2, axis2=-1))
# 3.02 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
We also note that it is slightly more efficient to stack the matrices as (R, C, N), where axis1=0 and axis2=1:
x = torch.rand(R, C, N)
%timeit x[[torch.arange(min(x.size()[:2]), dtype=torch.long)] * 2 + [...]].t()
# 19.6 µs ± 662 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit torch.from_numpy(x.numpy().diagonal(axis1=0, axis2=1))
# 2.58 µs ± 73.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)