Thank you once again, @ptrblck !
I think I’ve managed to take your example and generalize it to a batch of vectors (and even a batch of multiple vectors, i.e. an input of shape [B, D, A], where B is the batch size, D is the number of vectors per batch element, and A is the vector size).
The code’s here,
def _merge_on_and_off_diagonal(on_diag, off_diag):
#store output shape, to remove and unsqueeze'd dims
output_shape = (*off_diag.shape[:-1],off_diag.shape[-2])
if(len(on_diag.shape)==1 and len(off_diag.shape)==2):
on_diag=on_diag.unsqueeze(0).unsqueeze(1)
off_diag=off_diag.unsqueeze(0).unsqueeze(1)
elif(len(on_diag.shape)==2 and len(off_diag.shape)==3):
on_diag=on_diag.unsqueeze(1)
off_diag=off_diag.unsqueeze(1)
#reform input on_diag and off_diag to shape
#B = batch, D = number of vector for given batch, A = dimension of vector
#on_diag shape: [B, D, A]
#off_diag shape: [B, D, A, A-1]
if(on_diag.shape[-1] != off_diag.shape[-2]):
raise ValueError("index on_diag.shape[-1] must match off_diag.shape[-2]")
dim=len(on_diag.shape)
tmp = torch.cat((on_diag[:,:,:-1].unsqueeze(dim), \
off_diag.view((*off_diag.shape[0:(dim-1)],off_diag.shape[-1], off_diag.shape[-2]))), dim=dim)
res = torch.cat( (tmp.view(*off_diag.shape[0:(dim-1)], -1), on_diag[:,:,-1].unsqueeze(2)), dim=dim-1 ).view(*off_diag.shape[0:(dim-1)], on_diag.shape[-1], on_diag.shape[-1])
return res.view(output_shape)
It’s not exactly pleasant but it seems to work for varying input sizes!
For example,
# Example 1: a single 4 x 4 matrix.
# `diag` carries the four diagonal entries; row i of `off_diag` carries the
# remaining three entries of matrix row i, in column order.
diag = torch.tensor([11, 22, 33, 44])
off_diag = torch.tensor(
    [
        [12, 13, 14],
        [21, 23, 24],
        [31, 32, 34],
        [41, 42, 43],
    ]
)
matrix = _merge_on_and_off_diagonal(diag, off_diag)
"""
returns torch.tensor([[11,12,13,14],
[21,22,23,24],
[31,32,33,34],
[41,42,43,44]])
"""
# Example 2: a batch of two identical 4 x 4 matrices.
# `diag` has shape [2, 4] and `off_diag` has shape [2, 4, 3]; the leading
# dimension is the batch.
diag = torch.tensor(
    [
        [11, 22, 33, 44],
        [11, 22, 33, 44],
    ]
)
off_diag = torch.tensor(
    [
        [
            [12, 13, 14],
            [21, 23, 24],
            [31, 32, 34],
            [41, 42, 43],
        ],
        [
            [12, 13, 14],
            [21, 23, 24],
            [31, 32, 34],
            [41, 42, 43],
        ],
    ]
)
matrix = _merge_on_and_off_diagonal(diag, off_diag)
"""
returns tensor([[[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]],
[[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]]])
"""
Thank you!