I need to extract a diagonal stripe from a matrix. Say I have a matrix of size KxN, where K and N are arbitrary and K > N. Given the matrix:
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
From it I need to extract a diagonal stripe: in this case, a (K-N+1)xN matrix built from the diagonals of the original that span all N columns (x marks the entries that are dropped):
[[ 0  x  x]
 [ 3  4  x]
 [ x  7  8]
 [ x  x 11]]
So the result matrix is:
[[ 0  4  8]
 [ 3  7 11]]
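To pin down the indexing rule behind this example: row r of the result is the diagonal that starts at row r of the original, so result[r][c] = a[r + c][c], and there are K - N + 1 such rows. Below is a minimal pure-Python sketch of that rule. The name `stripe_reference` is made up for illustration, and it is meant only as a slow reference for checking faster versions, not as a solution.

```python
def stripe_reference(a):
    """Return the diagonal stripe of a K x N nested list, assuming K >= N."""
    k, n = len(a), len(a[0])
    assert k >= n
    # Row r of the result walks down the diagonal that starts at a[r][0].
    return [[a[r + c][c] for c in range(n)] for r in range(k - n + 1)]


a = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
print(stripe_reference(a))  # [[0, 4, 8], [3, 7, 11]]
```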
An additional problem is that I actually have a tensor of size, say, [150, 182, 91]: the first dimension is the batch size, and the matrix I am interested in is the 182x91 one (the sizes here are just examples).
I need to run the function on the 182x91 matrix for each of the 150 batch entries separately.
I do have a working solution, but it makes my whole model at least 20 times slower (credit for the code goes to layog from StackOverflow). Here is the code, written for PyTorch 0.4 (CUDA 9.1):
In [1]: import torch
In [2]: def stripe(a):
   ...:     i, j = a.size()
   ...:     assert i > j
   ...:     # note: this keeps i - j diagonals, one fewer than the i - j + 1 in the example above
   ...:     out = torch.zeros((i - j, j))
   ...:     # this loop is probably the bottleneck
   ...:     for diag in range(0, i - j):
   ...:         out[diag] = torch.diag(a, -diag)
   ...:     return out
In [3]: a = torch.randn((182, 91)).cuda()
In [5]: output = stripe(a)
In [6]: output.size()
Out[6]: torch.Size([91, 91])
In [7]: a = torch.randn((150, 182, 91))
# we map the stripe function over the first dimension of the tensor using torch.unbind
# this is a potential bottleneck
In [8]: output = list(map(stripe, torch.unbind(a, 0)))
In [9]: output = torch.stack(output, 0)
In [10]: output.size()
Out[10]: torch.Size([150, 91, 91])
Any potential for optimization here?
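One candidate that would remove both Python-level loops (the loop over diagonals and the map over the batch) is to build the row and column indices once and gather the whole stripe with advanced indexing. The sketch below is written against a recent PyTorch release and is untested on 0.4. The name `stripe_indexed` is hypothetical, and it returns i - j + 1 rows per matrix, following the 4x3 example rather than the loop version.

```python
import torch


def stripe_indexed(a):
    """Batched diagonal stripe via advanced indexing.

    a: tensor of shape (batch, i, j) with i >= j.
    Returns a new tensor of shape (batch, i - j + 1, j).
    """
    b, i, j = a.size()
    assert i >= j
    cols = torch.arange(j, dtype=torch.long, device=a.device)          # shape (j,)
    starts = torch.arange(i - j + 1, dtype=torch.long, device=a.device)
    rows = starts.unsqueeze(1) + cols                                   # rows[r, c] = r + c
    # Broadcasting rows (i-j+1, j) against cols (j,) picks a[n, r + c, c].
    return a[:, rows, cols]


a = torch.arange(12).reshape(1, 4, 3)
print(stripe_indexed(a).tolist())  # [[[0, 4, 8], [3, 7, 11]]]
```

Because indexing copies the selected elements, the result does not share memory with the input.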
For reference, in case it helps, here is the same operation implemented in NumPy:
>>> import numpy as np
>>>
>>> def stripe(a):
...     a = np.asanyarray(a)
...     i, j = a.shape
...     assert i >= j
...     k, l = a.strides
...     return np.lib.stride_tricks.as_strided(a, (i-j+1, j), (k, k+l))
...
>>> a = np.arange(24).reshape(6, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])
>>> stripe(a)
array([[ 0,  5, 10, 15],
       [ 4,  9, 14, 19],
       [ 8, 13, 18, 23]])
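The NumPy stride trick above might carry over to PyTorch more or less directly, assuming `Tensor.as_strided` is available in the installed version (I have not checked it on 0.4). The sketch below extends it to a batch dimension by leaving the batch stride untouched. The name `stripe_strided` is hypothetical, and like the NumPy version it returns i - j + 1 rows per matrix.

```python
import torch


def stripe_strided(a):
    """Batched diagonal stripe as a strided view.

    a: tensor of shape (batch, i, j) with i >= j.
    Returns a view of shape (batch, i - j + 1, j) that shares memory with a.
    """
    b, i, j = a.size()
    assert i >= j
    sb, sk, sl = a.stride()
    # Stepping one column to the right in the result moves one row down
    # and one column right in the input, hence the stride sk + sl.
    return a.as_strided((b, i - j + 1, j), (sb, sk, sk + sl))


a = torch.arange(24).reshape(1, 6, 4)
print(stripe_strided(a)[0].tolist())
# [[0, 5, 10, 15], [4, 9, 14, 19], [8, 13, 18, 23]]
```

Since the result is a view, no data is copied and no Python loop runs over the batch or the diagonals. Writing into the result would also modify the input, so calling `.contiguous()` or `.clone()` on it first is the safer choice if an independent tensor is needed.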