Stacking multiple tensors in a specified indexing order

Hi,

I want to stack two tensors along a dimension, but not sequentially. Rather, I want to specify the particular indexing of the stacking along that dimension. As a concrete example, here is how this should work for just two tensors A and B:

A=torch.tensor([[1,2,3],
                [4,5,6],
                [7,8,9]])
B=torch.tensor([[10,11,12],
                [13,14,15],
                [16,17,18],
                [19,20,21]])

Stacking A and B along dim = 0, where these indices give the positions of A's rows in the result:

stack_indexing_A = torch.tensor([1,3,6])

should yield this output:

result = torch.tensor([[10,11,12], 
                       [1,2,3],
                       [13,14,15],
                       [4,5,6],
                       [16,17,18],
                       [19,20,21],
                       [7,8,9]])

Basically, I interleave A with B using a specified index ordering for one of the two (A in this case).

This needs to be done without allocating a `torch.empty` tensor and filling it in, because both tensors are part of the computational graph of my model, and assigning into a new empty tensor might detach the result from the computation graph. I have tried `torch.gather`, but as far as I understand it only works on a single tensor. Thanks for any help!

Why would this be the case?
It works fine in my minimal example:

import torch

A = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]]).float().requires_grad_()
B = torch.tensor([[10,11,12],
                  [13,14,15],
                  [16,17,18],
                  [19,20,21]]).float().requires_grad_()

stack_indexing_A = torch.tensor([1, 3, 6])
# Boolean mask marking the output rows that come from A
stack_mask_A = torch.zeros(A.size(0) + B.size(0), dtype=torch.bool)
stack_mask_A[stack_indexing_A] = True

C = torch.empty(A.size(0) + B.size(0), A.size(1))
C[stack_mask_A] = A   # scatter A's rows into the masked positions
C[~stack_mask_A] = B  # fill the remaining rows with B, in order

print(C)
# tensor([[10., 11., 12.],
#         [ 1.,  2.,  3.],
#         [13., 14., 15.],
#         [ 4.,  5.,  6.],
#         [16., 17., 18.],
#         [19., 20., 21.],
#         [ 7.,  8.,  9.]], grad_fn=<IndexPutBackward0>)

C.mean().backward()
print(A.grad)
# tensor([[0.0476, 0.0476, 0.0476],
#         [0.0476, 0.0476, 0.0476],
#         [0.0476, 0.0476, 0.0476]])
print(B.grad)
# tensor([[0.0476, 0.0476, 0.0476],
#         [0.0476, 0.0476, 0.0476],
#         [0.0476, 0.0476, 0.0476],
#         [0.0476, 0.0476, 0.0476]])
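If you would still rather avoid writing into a fresh `torch.empty` buffer at all, another option is to index a `torch.cat` of the two tensors with a precomputed permutation. The permutation itself is a plain integer tensor outside the graph, so only differentiable ops (`cat` and indexing) touch A and B. A sketch for the same example:

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]]).float().requires_grad_()
B = torch.tensor([[10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18],
                  [19, 20, 21]]).float().requires_grad_()

stack_indexing_A = torch.tensor([1, 3, 6])

n = A.size(0) + B.size(0)
mask = torch.zeros(n, dtype=torch.bool)
mask[stack_indexing_A] = True

# Build a permutation: output row i is row perm[i] of cat([A, B], dim=0).
perm = torch.empty(n, dtype=torch.long)
perm[mask] = torch.arange(A.size(0))                # A's rows land at the masked positions
perm[~mask] = torch.arange(B.size(0)) + A.size(0)   # B's rows fill the rest, in order

C = torch.cat([A, B], dim=0)[perm]
print(C)
```

Both versions backpropagate correctly; this one just routes gradients through `CatBackward` and `IndexBackward` instead of `IndexPutBackward`.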