Selecting multiple indices

I want to ‘grow’ a matrix using a set of rules, where each rule maps a value to a 3x3 block. Example of rules: 0->[[1,1,1],[0,0,0],[2,2,2]], 1->[[2,2,2],[2,2,2],[2,2,2]], 2->[[0,0,0],[0,0,0],[0,0,0]].
Each growth step replaces every cell with the block for its value, so the matrix triples in height and width. Example of growing a matrix:
[[0]] -> [[1,1,1],[0,0,0],[2,2,2]] ->
[[2,2,2,2,2,2,2,2,2],
 [2,2,2,2,2,2,2,2,2],
 [2,2,2,2,2,2,2,2,2],
 [1,1,1,1,1,1,1,1,1],
 [0,0,0,0,0,0,0,0,0],
 [2,2,2,2,2,2,2,2,2],
 [0,0,0,0,0,0,0,0,0],
 [0,0,0,0,0,0,0,0,0],
 [0,0,0,0,0,0,0,0,0]]
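A minimal sketch of one growth step, assuming the three example rules are stored as a lookup tensor `rules` of shape (3, 3, 3), where `rules[v]` is the 3x3 block for value `v`; `grow_step` is a hypothetical helper name. Note the `permute` before `reshape`: `rules[m]` has shape (rows, cols, 3, 3), and the block rows have to be interleaved with the matrix rows before flattening, otherwise each 3x3 block ends up flattened into a single output row instead of matching the example above.

```python
import torch

# Lookup table for the example rules: rules[v] is the 3x3 block that replaces value v.
rules = torch.tensor([
    [[1, 1, 1], [0, 0, 0], [2, 2, 2]],  # rule for 0
    [[2, 2, 2], [2, 2, 2], [2, 2, 2]],  # rule for 1
    [[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # rule for 2
])

def grow_step(m: torch.Tensor) -> torch.Tensor:
    """Replace every cell of the (h, w) matrix m with its 3x3 rule block."""
    h, w = m.shape
    blocks = rules[m.long()]             # (h, w, 3, 3): integer (long) indexing
    blocks = blocks.permute(0, 2, 1, 3)  # (h, 3, w, 3): block rows next to matrix rows
    return blocks.reshape(h * 3, w * 3)  # (3h, 3w)

m = torch.tensor([[0]])
m = grow_step(m)  # 3x3
m = grow_step(m)  # 9x9
print(m)
```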

This is the code I’ve been trying to get working in PyTorch, with the NumPy version alongside for comparison:

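import numpy as np
import torch

# 10 rule sets; each maps one of 256 values to a 3x3 block for each of 3 channels
rules = np.random.randint(256, size=(10, 256, 3, 3, 3))
rules_tensor = torch.randint(256, size=(10, 256, 3, 3, 3),
                             dtype=torch.uint8, device=torch.device('cuda'))

# use only the first rule set: shape (256, 3, 3, 3)
rules = rules[0]
rules_tensor = rules_tensor[0]

# start from a single cell with value 128
seed = np.array([[128]])
seed_tensor = torch.tensor([[128]], dtype=torch.uint8, device=torch.device('cuda'))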

# output: 27x27 cells (three growth steps from 1x1), one slice per channel
decode = np.empty((3**3, 3**3, 3), dtype=np.uint8)
decode_tensor = torch.empty((3**3, 3**3, 3), dtype=torch.uint8,
                            device=torch.device('cuda'))

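for i in range(3):              # one pass per channel
    grow = seed
    grow_tensor = seed_tensor
    for j in range(1, 4):       # three growth steps: 3x3, 9x9, 27x27
        # look up the 3x3 block of channel i for every cell
        grow = rules[grow, :, :, i].reshape(3**j, -1)
        # PyTorch version of the line above; this is the one that doesn't behave like NumPy
        grow_tensor = rules_tensor[grow_tensor, :, :, i].reshape(3**j, -1)

    decode[..., i] = grow
    decode_tensor[..., i] = grow_tensor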

I can’t seem to select indices the same way as in NumPy in this line:

grow = rules[grow,:,:,i].reshape(3**j,-1)

Is there a way to do the same in PyTorch?
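One likely explanation (an assumption, not checked against the exact error message): `seed_tensor` and `rules_tensor` are `torch.uint8`, and PyTorch interprets a `uint8` index tensor as a boolean mask (a deprecated legacy behavior) rather than as integer indices the way NumPy does. Below is a sketch of the same loop that casts the index to `torch.long` first. It also permutes each `(h, w, 3, 3)` block tensor before reshaping so the blocks keep the layout of the growth example at the top, which the plain `.reshape(3**j, -1)` does not do. The CPU fallback when CUDA is unavailable is an addition for testing.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same shapes as in the question: 256 possible values, each mapped to a 3x3 block
# for each of 3 channels (last axis).
rules_tensor = torch.randint(256, size=(10, 256, 3, 3, 3), dtype=torch.uint8, device=device)
rules_tensor = rules_tensor[0]  # (256, 3, 3, 3)

seed_tensor = torch.tensor([[128]], dtype=torch.uint8, device=device)
decode_tensor = torch.empty((3**3, 3**3, 3), dtype=torch.uint8, device=device)

for i in range(3):
    grow_tensor = seed_tensor
    for j in range(1, 4):
        # .long(): a uint8 index tensor would be treated as a boolean mask.
        blocks = rules_tensor[grow_tensor.long(), :, :, i]  # (h, w, 3, 3)
        h, w = grow_tensor.shape
        # Interleave block rows with matrix rows before flattening to (3**j, 3**j).
        grow_tensor = blocks.permute(0, 2, 1, 3).reshape(h * 3, w * 3)
    decode_tensor[..., i] = grow_tensor
```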