Block_diag for tensors of more than 2 dimensions

I am representing an adjacency matrix as a 4-dimensional tensor of shape (batch_size, n_nodes, n_nodes, n_edge_features). Say I have the adjacency matrices of 2 subgraphs, t1.shape = (1, 46, 46, 8) and t2.shape = (1, 4, 4, 8). I want to ‘merge’ the two graphs into a single adjacency matrix t3 with t3.shape = (1, 50, 50, 8). In other words, I want to place the two tensors on the diagonal of the two node dimensions (using something like block_diag), with all off-diagonal entries set to 0.
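One way to get "block_diag, but for 4-D tensors" is to call torch.block_diag on every 2-D (n_nodes x n_nodes) slice, one per batch element and edge feature, and stack the results back together. This is a minimal sketch, assuming all inputs share the same batch size and feature size; `block_diag_per_slice` is a hypothetical helper name, and its Python loops will be slow for large batches or many features:

```python
import torch


def block_diag_per_slice(*tensors: torch.Tensor) -> torch.Tensor:
    """Block-diagonal merge of tensors shaped (bs, n_i, n_i, feat).

    torch.block_diag only accepts tensors with at most 2 dimensions, so it is
    applied to every (batch, feature) slice separately and the results are
    stacked back into shape (bs, sum(n_i), sum(n_i), feat).
    """
    bs, feat = tensors[0].shape[0], tensors[0].shape[-1]
    batches = []
    for b in range(bs):
        # One 2-D block-diagonal matrix per edge feature
        per_feature = [
            torch.block_diag(*[t[b, :, :, f] for t in tensors])
            for f in range(feat)
        ]
        batches.append(torch.stack(per_feature, dim=-1))  # (N, N, feat)
    return torch.stack(batches, dim=0)  # (bs, N, N, feat)


# Shapes from the question (random values for illustration)
t1 = torch.randn(1, 46, 46, 8)
t2 = torch.randn(1, 4, 4, 8)
t3 = block_diag_per_slice(t1, t2)
print(t3.shape)  # torch.Size([1, 50, 50, 8])
```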

Unfortunately, torch.block_diag only accepts tensors with at most 2 dimensions, and I am not sure how to use the other concatenation/stacking functions (e.g. torch.stack, torch.cat) to achieve this.
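With torch.cat, one option is to first zero-pad each tensor along the second node dimension so both have width n1 + n2, and then concatenate them along the first node dimension. A minimal sketch for two graphs, assuming both share batch size, feature size, dtype and device; `merge_two_graphs` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def merge_two_graphs(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Block-diagonal merge of two tensors shaped (bs, n, n, feat) via pad + cat."""
    n1, n2 = t1.shape[1], t2.shape[1]
    # F.pad reads (left, right) pairs starting from the LAST dimension:
    # (0, 0) leaves the feature dim alone, the next pair pads the second node dim.
    top = F.pad(t1, (0, 0, 0, n2))     # (bs, n1, n1 + n2, feat), zeros on the right
    bottom = F.pad(t2, (0, 0, n1, 0))  # (bs, n2, n1 + n2, feat), zeros on the left
    # Concatenate along the first node dimension
    return torch.cat([top, bottom], dim=1)  # (bs, n1 + n2, n1 + n2, feat)


t1 = torch.randn(1, 46, 46, 8)
t2 = torch.randn(1, 4, 4, 8)
t3 = merge_two_graphs(t1, t2)
print(t3.shape)  # torch.Size([1, 50, 50, 8])
```

This avoids Python loops over the batch and feature dimensions; merging more than two graphs means calling it repeatedly on the running result.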

I came up with the following; any improvements are welcome:

      import torch

      # t1.shape = (bs, n_nodes, n_nodes, feature_size) = (1, 2, 2, 2)
      t1 = torch.tensor([[[1, 1], [2, 2]],
                         [[3, 3], [4, 4]]], dtype=torch.float32).unsqueeze(dim=0)
      # t2.shape = (bs, n_nodes, n_nodes, feature_size) = (1, 3, 3, 2)
      t2 = torch.tensor([[[5, 5], [6, 6], [7, 7]],
                         [[8, 8], [9, 9], [10, 10]],
                         [[11, 11], [12, 12], [13, 13]]],
                        dtype=torch.float32).unsqueeze(dim=0)

      n1, n2 = t1.shape[1], t2.shape[1]

      # t3.shape = (1, 5, 5, 2); zeros with the same dtype/device as t1
      t3 = t1.new_zeros((t1.shape[0], n1 + n2, n1 + n2, t1.shape[-1]))

      # Index both node dimensions in a single expression (no chained indexing).
      # The second block must span n2 columns, not n1; the off-diagonal blocks stay 0.
      t3[:, :n1, :n1, :] = t1
      t3[:, n1:, n1:, :] = t2
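The preallocate-and-assign idea in the snippet above extends to any number of subgraphs by keeping a running offset along the node dimensions. A minimal sketch, assuming every input is shaped (bs, n_i, n_i, feat) with a shared batch size, feature size, dtype and device; `block_diag_nd` is a hypothetical name:

```python
import torch


def block_diag_nd(*tensors: torch.Tensor) -> torch.Tensor:
    """Place each (bs, n_i, n_i, feat) tensor on the diagonal of dims 1 and 2."""
    bs, feat = tensors[0].shape[0], tensors[0].shape[-1]
    total = sum(t.shape[1] for t in tensors)
    # Everything not covered by a block stays zero
    out = tensors[0].new_zeros((bs, total, total, feat))
    offset = 0
    for t in tensors:
        n = t.shape[1]
        out[:, offset:offset + n, offset:offset + n, :] = t
        offset += n
    return out


# Same toy data as in the question
t1 = torch.tensor([[[1, 1], [2, 2]],
                   [[3, 3], [4, 4]]], dtype=torch.float32).unsqueeze(dim=0)
t2 = torch.tensor([[[5, 5], [6, 6], [7, 7]],
                   [[8, 8], [9, 9], [10, 10]],
                   [[11, 11], [12, 12], [13, 13]]], dtype=torch.float32).unsqueeze(dim=0)
t3 = block_diag_nd(t1, t2)  # shape (1, 5, 5, 2)
```

For each feature f, t3[0, :, :, f] should equal torch.block_diag(t1[0, :, :, f], t2[0, :, :, f]), which is a handy sanity check.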