Ok, so I couldn’t find a super easy or efficient way to do this, and if people have better solutions I’d love to see them, but this is what I came up with:
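```python
import torch
import torch.nn.functional as F


def block_diag(x, n):
    """Repeats a tensor diagonally n times as specified by @zfzhang."""
    outputs = []
    for row in x:
        # One (n, n + C - 1) block per row of x; match x's dtype and device.
        out = x.new_zeros(n, n + x.shape[1] - 1)
        for ii, elem in enumerate(row):
            # n x n diagonal filled with this element...
            d = torch.diag(elem.expand(n))
            # ...shifted right by ii columns via zero padding on the last dim.
            padded = F.pad(d, pad=(ii, x.shape[1] - ii - 1))
            out = out + padded
        outputs.append(out)
    # Stack the blocks vertically: shape (R * n, n + C - 1).
    return torch.cat(outputs)
```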
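For reference, this is what the loop computes, written out as a formula. It is a reading of the code above, with R and C standing for `x.shape[0]` and `x.shape[1]`. Every entry not covered by the formula is zero, and the result has shape Rn × (n + C - 1):

```latex
\mathrm{out}_{rn + i,\; i + j} = x_{r,j}, \qquad 0 \le r < R,\; 0 \le i < n,\; 0 \le j < C

% Worked example with R = 1, C = 3:
x = \begin{pmatrix} 1 & 2 & 3 \end{pmatrix},\; n = 2
\;\Longrightarrow\;
\mathrm{out} = \begin{pmatrix} 1 & 2 & 3 & 0 \\ 0 & 1 & 2 & 3 \end{pmatrix}
```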
This method works, but is… not fast:
```
In [1]: x = torch.rand(10, 10)
In [2]: %timeit block_diag(x, 2)
1.78 ms ± 53.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [3]: %timeit block_diag(x, 10)
1.76 ms ± 60.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [4]: x = torch.rand(50, 50)
In [5]: %timeit block_diag(x, 2)
43.2 ms ± 2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [6]: %timeit block_diag(x, 10)
44.9 ms ± 1.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
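The method should work with autograd, but it may be infeasibly slow for large matrices. It runs a Python loop over every element of `x`, so going from 10×10 to 50×50 (25× as many elements) made it roughly 25× slower, while `n` barely mattered. Hope that helps a bit. Again, if someone else has a better solution, please share, I'm very curious!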
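One possible loop-free alternative is to build all blocks at once with advanced indexing. This is a sketch that has not been benchmarked against the timings above. It assumes `x` is a 2-D tensor of shape (R, C), and `block_diag_vectorized` is a hypothetical name. For every shift `i`, it writes `x[r, j]` into position `(i, i + j)` of block `r`, which should reproduce the output of the loop version:

```python
import torch


def block_diag_vectorized(x: torch.Tensor, n: int) -> torch.Tensor:
    """Loop-free sketch of block_diag above; assumes x is 2-D with shape (R, C)."""
    R, C = x.shape
    # One (n, n + C - 1) block per row of x, all zeros to start.
    out = x.new_zeros(R, n, n + C - 1)
    # Row i of each block holds the source row shifted right by i columns,
    # so element j lands in column i + j.
    rows = torch.arange(n, device=x.device).unsqueeze(1)  # shape (n, 1)
    cols = rows + torch.arange(C, device=x.device)        # shape (n, C)
    # Advanced indexing writes x[r, j] to out[r, i, i + j] for every r, i, j.
    # The in-place write is recorded by autograd, so gradients flow back to x.
    out[:, rows, cols] = x.unsqueeze(1).expand(R, n, C)
    # Stack the blocks vertically, like torch.cat(outputs) in the loop version.
    return out.reshape(R * n, n + C - 1)


x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(block_diag_vectorized(x, 2))
# rows: [1, 2, 3, 0], [0, 1, 2, 3], [4, 5, 6, 0], [0, 4, 5, 6]
```

The design trades the R·C small `diag`/`pad`/add calls for a single indexed write into a preallocated (R, n, n + C - 1) tensor, so the Python overhead no longer grows with the number of elements in `x`.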