Creating a new tensor using multiple tensors and label information

Let’s assume that I have 10 different tensors of size 128x10 (128 is the batch size and 10 is the number of classes). The names of these tensors are outputs_0, outputs_1, …, outputs_9.

I also have another tensor of size 128x10, which is the one-hot encoding of the labels of my batch input data. This tensor is named one_hot.

Now, what I am trying to do is: for each row, pick one of the outputs tensors based on the one_hot tensor and take the corresponding row from it. The final output tensor that I would like to create is again of size 128x10.

Let's assume that the label of the first data point x_0 is 0. Then the first row of my final output tensor should be the first row of the outputs_0 tensor. If the label of the second data point x_1 is 7, then the second row of my final output tensor should be the second row of the outputs_7 tensor. If the label of the third data point x_2 is 5, then the third row of my final output tensor should be the third row of the outputs_5 tensor. And so on…

Could you please help me create such a final output tensor from the 10 output tensors and the one-hot label tensor?
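The selection rule described above can be sketched with PyTorch advanced indexing: stack the ten tensors into one tensor of shape [10, B, 10], then index it with the label vector and a row-index vector. This is a minimal sketch, assuming the labels are available as an integer tensor of shape [B] (named labels here); the random tensors are stand-ins for the real outputs_0 to outputs_9, and the three labels reproduce the x_0, x_1, x_2 example.

```python
import torch

torch.manual_seed(0)

num_outputs, num_classes = 10, 10
labels = torch.tensor([0, 7, 5])  # labels of x_0, x_1, x_2 from the example above
batch_size = labels.shape[0]

# Stand-ins for outputs_0 ... outputs_9, each of shape [batch_size, num_classes]
outputs_list = [torch.randn(batch_size, num_classes) for _ in range(num_outputs)]

# Shape [num_outputs, batch_size, num_classes]
stacked = torch.stack(outputs_list)

# Pair each label with its row index: row i of the result is row i of outputs_{labels[i]}
rows = torch.arange(batch_size)
final = stacked[labels, rows]

print(final.shape)  # torch.Size([3, 10])
```

The two index tensors are consumed pairwise, so the result has one row per sample and no Python-level loop over the batch is needed.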

    for X, y in loader:
        X, y = X.to(device), y.to(device)

        item_count = X.shape[0]

        outputs = model(X)

        print("outputs shape ", outputs.shape)

        outputs_0 = my_act_func0(outputs) 
        outputs_1 = my_act_func1(outputs) 
        outputs_2 = my_act_func2(outputs) 
        outputs_3 = my_act_func3(outputs) 
        outputs_4 = my_act_func4(outputs) 
        outputs_5 = my_act_func5(outputs) 
        outputs_6 = my_act_func6(outputs) 
        outputs_7 = my_act_func7(outputs) 
        outputs_8 = my_act_func8(outputs) 
        outputs_9 = my_act_func9(outputs) 

        print("outputs_0 shape ", outputs_0.shape)

        all_outputs = torch.zeros([10, item_count, 10], device=outputs.device)  # same device as the model outputs

        all_outputs[0] = outputs_0
        all_outputs[1] = outputs_1
        all_outputs[2] = outputs_2
        all_outputs[3] = outputs_3
        all_outputs[4] = outputs_4
        all_outputs[5] = outputs_5
        all_outputs[6] = outputs_6
        all_outputs[7] = outputs_7
        all_outputs[8] = outputs_8
        all_outputs[9] = outputs_9

        print("all_outputs shape ", all_outputs.shape)

        #print("output shape ", outputs.shape)
        #print("y ", y.shape)

        one_hot = torch.nn.functional.one_hot(y,num_classes=10)
        one_hot = one_hot.to(torch.float32)
        print("one hot", one_hot )
        print("one hot shape", one_hot.shape )

The output of the code above looks like this:

    outputs shape torch.Size([128, 10])
    outputs_0 shape torch.Size([128, 10])
    all_outputs shape torch.Size([10, 128, 10])
    one hot tensor([[1., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            ...,
            [0., 0., 0., ..., 0., 0., 1.],
            [0., 1., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')
    one hot shape torch.Size([128, 10])
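Since all_outputs (shape [10, 128, 10]) and one_hot (shape [128, 10]) are already built in the code above, one option is to use the one-hot matrix as a mask: transpose it to [10, 128], add a trailing dimension so it broadcasts against all_outputs, multiply, and sum over the first dimension. Only the tensor selected by each row's label contributes, so the result is [128, 10]. A minimal sketch with random stand-in data, assuming the number of output tensors equals the number of classes (10 here) and that both tensors are on the same device (torch.zeros creates its tensor on the CPU unless a device is passed).

```python
import torch

torch.manual_seed(0)

batch_size, num_outputs, num_classes = 128, 10, 10

# Stand-ins for the tensors built in the loop above
all_outputs = torch.randn(num_outputs, batch_size, num_classes)  # [10, B, 10]
y = torch.randint(0, num_classes, (batch_size,))                 # labels, [B]
one_hot = torch.nn.functional.one_hot(y, num_classes=num_classes).to(torch.float32)  # [B, 10]

# one_hot.T is [10, B]; unsqueeze(-1) makes it [10, B, 1] so it broadcasts over the last dim
mask = one_hot.T.unsqueeze(-1)

# Zero out every tensor except the one selected by each row's label, then collapse dim 0
final_outputs = (all_outputs * mask).sum(dim=0)  # [B, 10]
```

This multiply-and-sum form assumes the outputs are finite, because 0 * inf gives nan. If integer labels are needed instead, one_hot.argmax(dim=1) recovers them from the one-hot tensor.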

Assume that the batch size is 1 (I didn’t understand what you are doing with multiple batches):

    outputs_list = [outputs_0, outputs_1,..., outputs_9] # each outputs_x has (1 x 10)
    all_outputs = [outputs_0]
    for i in range(item_count - 1):
        label = int(torch.argmax(all_outputs[-1], dim=1))
        all_outputs.append(outputs_list[label])

I have 10 outputs_x tensors and also a one-hot tensor (or the label tensor y).

I need to pick the correct tensor and row based on the one_hot tensor or the labels (the y tensor), and use them to construct a new tensor, as I tried to explain in my initial post.

Also, the batch size matters: the solution should work with different batch sizes.

In my case, I chose a batch size of 128, so the tensors are 128x10.
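A selection that reads the batch size from the label tensor at run time works for any batch size, including a smaller last batch from a DataLoader. As a sketch of one such approach, the hypothetical helper select_by_label below uses torch.gather along the stacked dimension; it assumes y is an int64 tensor of shape [B] with values in range(10), which is what torch.gather requires for its index.

```python
import torch


def select_by_label(outputs_list, y):
    """Return a [B, C] tensor whose row i is row i of outputs_list[y[i]].

    Hypothetical helper: outputs_list holds K tensors of shape [B, C],
    y is an int64 tensor of shape [B] with values in range(K).
    """
    stacked = torch.stack(outputs_list, dim=1)  # [B, K, C]
    batch_size, _, num_classes = stacked.shape
    # gather needs an index with as many dims as stacked: [B, 1, C]
    index = y.reshape(batch_size, 1, 1).expand(batch_size, 1, num_classes)
    # out[b, 0, c] = stacked[b, y[b], c]; drop the singleton dim afterwards
    return torch.gather(stacked, 1, index).squeeze(1)  # [B, C]


# The same helper works unchanged for different batch sizes
for batch_size in (1, 37, 128):
    outputs_list = [torch.randn(batch_size, 10) for _ in range(10)]
    y = torch.randint(0, 10, (batch_size,))
    print(batch_size, select_by_label(outputs_list, y).shape)
```

Nothing in the helper is tied to 128: the batch size and number of classes are taken from the stacked tensor's shape.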

This seems to work for my needs:

    for X, y in loader:
        X, y = X.to(device), y.to(device)

        item_count = X.shape[0]

        outputs = model(X)

        outputs_0 = my_act_func0(outputs) 
        outputs_1 = my_act_func1(outputs) 
        outputs_2 = my_act_func2(outputs) 
        outputs_3 = my_act_func3(outputs) 
        outputs_4 = my_act_func4(outputs) 
        outputs_5 = my_act_func5(outputs) 
        outputs_6 = my_act_func6(outputs) 
        outputs_7 = my_act_func7(outputs) 
        outputs_8 = my_act_func8(outputs) 
        outputs_9 = my_act_func9(outputs) 

        outputs_list = [outputs_0, outputs_1, outputs_2, outputs_3, outputs_4, outputs_5, outputs_6, outputs_7, outputs_8, outputs_9] 

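        final_outputs = []

        y_list = y.tolist()

        # Row `count` of the result comes from the outputs tensor picked by that row's label
        for count, value in enumerate(y_list):
            final_outputs.append(outputs_list[value][count])

        # Combine the selected rows into a single [batch_size, 10] tensor
        final_outputs = torch.stack(final_outputs)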
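If each my_act_funcK acts on every row independently (an assumption; the functions are not shown), the ten full-batch activations do not all have to be computed: each activation can be applied only to the rows whose label selects it. A sketch under that assumption, with hypothetical stand-in activations (act_funcs) in place of my_act_func0 to my_act_func9, and labels assumed to lie in range(10) so that every row is filled exactly once:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for my_act_func0 ... my_act_func9; the real ones are not shown.
# Each is assumed to act on every row independently (no mixing across the batch).
act_funcs = [lambda t, k=k: torch.sigmoid(t * (k + 1)) for k in range(10)]

batch_size = 128
outputs = torch.randn(batch_size, 10, requires_grad=True)  # stand-in for model(X)
y = torch.randint(0, 10, (batch_size,))                    # labels, all in range(10)

final_outputs = torch.empty_like(outputs)
for k, act in enumerate(act_funcs):
    mask = y == k                              # rows whose label is k
    final_outputs[mask] = act(outputs[mask])   # only those rows go through activation k

# Gradients still reach outputs through the masked assignments
final_outputs.sum().backward()
```

Compared with the loop over y_list, this does ten masked assignments per batch instead of one Python-level indexing step per sample. If an activation mixes information across rows (for example, a normalization over the batch), this shortcut does not apply and computing all ten full-batch outputs, as in the code above, is the safe choice.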