I have a tensor of shape (batch_size, seq_len). From this, I want to make a tensor of shape (batch_size, seq_len+1, seq_len+1), where the matrix for each batch element is an off-diagonal matrix with entries taken from the given tensor.
More precisely, from a given tensor
x = [[1, 2, 3], [4, 5, 6]],
I want to make a function that returns
y = [[[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3], [0, 0, 0, 0]],
[[0, 4, 0, 0], [0, 0, 5, 0], [0, 0, 0, 6], [0, 0, 0, 0]]].
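For reference, here is a minimal sketch of one way this off-diagonal layout might be produced, assuming the installed PyTorch version provides torch.diag_embed with an offset argument. I have not checked how it compares to other approaches in speed or memory.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (batch_size, seq_len)

# offset=1 places each row of x on the first superdiagonal, so every
# per-batch matrix has shape (seq_len + 1, seq_len + 1).
y = torch.diag_embed(x, offset=1)

print(y.shape)  # torch.Size([2, 4, 4])
```

The output is built per batch element, so no (batch_size * seq_len)-sized square matrix is created.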
I can’t directly apply torch.diag, since for a 2-D input it returns the diagonal of that tensor instead of building a diagonal matrix. I think it is possible to do with torch.diagflat, but that is highly memory-inefficient, since it flattens the input and constructs a tensor of shape (batch_size * seq_len + 1, batch_size * seq_len + 1) when used with offset=1.
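To make the problem concrete, this small check shows what the two functions do with the example input. It is only an illustration of the behaviour described above, not a proposed solution.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

# For a 2-D input, torch.diag extracts the main diagonal.
d = torch.diag(x)

# torch.diagflat flattens x to 6 elements first, then builds one big matrix.
f = torch.diagflat(x, offset=1)

print(d)        # tensor([1, 5])
print(f.shape)  # torch.Size([7, 7])
```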
EDIT: I just found that, for my purposes, it is enough to write a function that transforms
x = [[1, 2, 3], [4, 5, 6]]
to
y = [[[1, 0, 0], [0, 2, 0], [0, 0, 3]],
[[4, 0, 0], [0, 5, 0], [0, 0, 6]]]
i.e. transform each row vector into a diagonal matrix, not an off-diagonal one.
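A minimal sketch of this simpler diagonal case, again assuming torch.diag_embed is available. The broadcasting variant (y_alt) is just an alternative I added for comparison, and the names y and y_alt are illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (batch_size, seq_len)

# Each row of x becomes the main diagonal of its own matrix.
y = torch.diag_embed(x)  # shape (batch_size, seq_len, seq_len)

# Alternative via broadcasting: scale the rows of an identity matrix.
y_alt = x.unsqueeze(-1) * torch.eye(x.shape[-1], dtype=x.dtype)

print(torch.equal(y, y_alt))  # True
```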