I am representing an adjacency matrix as a 4-dimensional tensor of shape (batch_size, n_nodes, n_nodes, n_edge_features). Say I have two adjacency matrices of subgraphs, with t1.shape = (1, 46, 46, 8) and t2.shape = (1, 4, 4, 8). I want to ‘merge’ the two graphs into a single adjacency tensor t3 with t3.shape = (1, 50, 50, 8). In other words, I want to stack the two tensors along the diagonal of the two node dimensions (using something like block_diag), with all off-diagonal entries equal to 0.
Unfortunately, torch.block_diag only accepts tensors with at most two dimensions, and I am not sure how to use other concatenation/stacking functions (e.g. torch.stack, torch.cat) to achieve this.
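One possible workaround is to skip torch.cat/torch.stack entirely: preallocate a zero tensor of the merged shape and copy each subgraph's adjacency into its diagonal block with slice assignment. The sketch below assumes all inputs share the same batch size, number of edge features, dtype and device. The helper name `block_diag_adjacency` is hypothetical, not a PyTorch function.

```python
import torch

def block_diag_adjacency(*adjs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: place (B, n_i, n_i, F) adjacency tensors on the
    diagonal of a single (B, N, N, F) tensor, where N = sum(n_i).
    Assumes every input has the same B, F, dtype and device."""
    batch_size, _, _, n_feat = adjs[0].shape
    total = sum(a.shape[1] for a in adjs)
    # Zero-initialised output, so every off-diagonal block stays 0
    out = adjs[0].new_zeros((batch_size, total, total, n_feat))
    offset = 0
    for a in adjs:
        n = a.shape[1]
        # Copy this subgraph into its own diagonal block (rows and columns)
        out[:, offset:offset + n, offset:offset + n, :] = a
        offset += n
    return out

t1 = torch.randn(1, 46, 46, 8)
t2 = torch.randn(1, 4, 4, 8)
t3 = block_diag_adjacency(t1, t2)
print(t3.shape)  # torch.Size([1, 50, 50, 8])
```

For any single batch element and edge feature, the result should match torch.block_diag applied to the corresponding 2D slices, e.g. `torch.block_diag(t1[0, :, :, k], t2[0, :, :, k])`. Preallocating and assigning slices also extends to any number of subgraphs without building zero-padded pieces for torch.cat.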