Mask data from a tensor

Hi,
I have a tensor a and want to select some data from it based on another tensor, b, whose values are -1, 0, or 1. For example:

a = torch.randn([7, 4, 30])
b = torch.tensor([[ 1,  0,  0,  0],
                  [ 0,  1,  1,  1],
                  [ 1,  0,  0,  0],
                  [ 0,  1,  1,  1],
                  [-1,  0,  0,  0],
                  [-1, -1,  1,  1],
                  [-1, -1,  0, -1]], dtype=torch.int32)

For each column of a, I want to select the rows where b == 0 (which already excludes b == -1), or do the same for b == 1. The selected data should stay in the same column, and the rest of the column should be filled with zeros or some other value. The final result would have shape [xx, 4, 30], where xx is less than 7 and depends on the index values in each column. I tried creating a mask (for example for b == 0) and using `masked_select`, but it returns a 1-D tensor, whereas I want the shape to be [xx, 4, 30].

Does anyone know a solution in PyTorch? The number of columns and rows is large in my code, and I don’t want to use a for loop to do this.
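A short sketch of why the `masked_select` attempt returns a 1-D tensor: `torch.masked_select` always flattens its result, and boolean indexing with a [T, B] mask keeps the feature dimension but merges the time and column dimensions. Either way, the column each selected row came from has to be tracked separately (here with `nonzero`). The random b below is only for illustration.

```python
import torch

a = torch.randn(7, 4, 30)
b = torch.randint(-1, 2, (7, 4), dtype=torch.int32)  # random labels in {-1, 0, 1}

mask = b == 0                                         # [7, 4] bool

# masked_select broadcasts the mask over the feature dim and returns a 1-D tensor
flat = torch.masked_select(a, mask.unsqueeze(-1))     # [num_selected * 30]

# boolean indexing keeps the feature dim but merges the time and column dims
rows = a[mask]                                        # [num_selected, 30]

# the column of each selected row must be recovered separately
cols = mask.nonzero(as_tuple=True)[1]                 # [num_selected]
```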

Update:

a = torch.randn([6, 4, 30])
b = torch.tensor([[ 1,  0,  0,  0],
                  [ 0,  1,  1,  1],
                  [ 1,  0,  0,  0],
                  [ 0,  1,  1,  1],
                  [-1,  0,  0,  0],
                  [-1, -1,  1,  1]], dtype=torch.int32)

Could you post the desired result for the calculation given your input tensors?
I’m not sure how the [-1, 0, 1] values should be treated at the moment.

Let me explain the problem another way. I have a multi-agent environment with two agents. I have one tensor, agent_index, which contains only 0 and 1, and one tensor for the done flag. The b tensor is calculated as follows:

agent_index = agent_index + 1  # 1 or 2
agent_index = (1 - done) * agent_index  # 0 where done, otherwise 1 or 2
agent_index = agent_index - 1  # -1 where done, otherwise 0 or 1
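A quick sketch checking what these three lines produce, together with an equivalent single `torch.where` call. It assumes agent_index and done are integer tensors of the same shape holding 0/1; the values are made up for illustration.

```python
import torch

agent_index = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.int32)
done = torch.tensor([[0, 0], [1, 0], [0, 1]], dtype=torch.int32)

# the three-step formula from the post
b = agent_index + 1            # 1 or 2
b = (1 - done) * b             # 0 where done
b = b - 1                      # -1 where done, otherwise the agent index

# the same labels in one call: -1 where done, otherwise agent_index
b_where = torch.where(done.bool(), torch.full_like(agent_index, -1), agent_index)
```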

The b tensor is equal to this agent_index tensor. I want to separate the data of the two agents using b as a mask. Each column of b is a trajectory from one of many parallel envs.

One example of the data:

b = agent_index = 
tensor([[ 1,  0,  1,  0],
        [ 0,  1,  0,  1],
        [ 1,  0,  1, -1],
        [ 0, -1,  0, -1],
        [ 1, -1,  1, -1],
        [-1, -1,  0, -1]], dtype=torch.int32)

a = observation = 
tensor([[[-1.7150e-01, -4.8424e-01,  8.8617e+00,  2.3157e+01, -8.4234e-03,
          -2.2124e-01, -3.6421e-01,  3.5387e+00,  1.5165e-01,  3.2926e-02,
           4.2830e-02,  0.0000e+00,  9.0558e-03,  2.3515e-01,  6.3245e-04,
           1.3912e-02,  1.1196e-01,  1.7685e+00, -2.2143e-01,  2.8138e+01,
          -3.0058e-01,  2.8635e+01, -2.6145e-01,  3.3667e+01, -1.9383e-01,
           3.3245e+01,  1.6753e-03,  8.9053e-03, -1.4090e-01,  8.9662e-03],
         [-2.2966e-01, -2.2966e-01,  3.6587e+01,  3.6587e+01,  1.2437e-02,
          -3.9594e-01,  1.9770e+00, -1.2419e+00, -4.0993e-01, -1.8288e-01,
          -1.2287e-01,  0.0000e+00, -1.5647e-01, -1.8838e-02, -1.4403e-01,
          -4.1478e-01, -3.0608e+00, -2.3535e+00, -3.0667e-01,  2.9029e+01,
          -3.8337e-01,  2.9532e+01, -3.4508e-01,  3.4562e+01, -2.7926e-01,
           3.4133e+01,  4.9681e-03,  3.0067e-02,  3.6508e-02, -3.0510e-03],
         [ 1.0990e-01, -3.9639e-01,  4.1975e+00,  1.4005e+01, -8.5840e-02,
           2.7769e-01, -3.4136e-02, -1.0512e+00, -2.7759e-01,  2.1368e-01,
           1.8352e-01,  0.0000e+00,  7.7976e-02, -3.3801e-01, -7.8642e-03,
          -6.0324e-02, -3.8388e-01, -2.0709e+00,  4.8057e-03,  2.4067e+01,
          -9.0611e-02,  2.4146e+01, -7.6136e-02,  2.9309e+01,  2.4385e-03,
           2.9243e+01, -1.7170e-04,  3.3451e-03,  6.6576e-02,  4.2948e-03],
         [-1.7396e-01, -1.7396e-01,  2.7653e+01,  2.7653e+01, -4.1336e-02,
          -9.6437e-02, -4.6129e-01, -2.1930e-01, -1.0034e-01, -1.9808e-02,
          -2.3761e-02,  0.0000e+00,  1.3177e-01, -1.0492e-01,  9.0433e-02,
          -2.0136e-01, -5.9576e-02,  3.6985e+00, -2.4864e-01,  2.0624e+01,
          -3.6042e-01,  2.0524e+01, -3.5792e-01,  2.5700e+01, -2.6856e-01,
           2.5780e+01, -5.9599e-03,  2.8747e-02, -1.1965e-01, -1.9982e-02]],
        [[-1.3764e-01, -1.3764e-01,  3.7657e+01,  3.7657e+01,  5.4948e-02,
           1.7581e-01,  9.1497e-01,  1.5440e+00,  2.1075e-01,  3.7513e-02,
           6.2444e-02,  0.0000e+00, -6.3372e-02, -3.9705e-01, -8.4234e-03,
          -2.2124e-01, -3.6421e-01,  3.5387e+00, -2.2035e-01,  2.8098e+01,
          -2.9856e-01,  2.8700e+01, -2.5260e-01,  3.3676e+01, -1.8546e-01,
           3.3164e+01, -1.4493e-03,  3.2164e-02,  1.8048e-01, -1.7941e-02],
         [-6.5071e-02, -3.5040e-01,  1.0754e+01,  2.6788e+01, -1.1363e-01,
          -3.4749e-01,  1.5066e+00,  3.9758e+00,  2.7810e-01,  2.5307e-01,
           2.2713e-01,  0.0000e+00,  1.2607e-01, -4.8451e-02,  1.2437e-02,
          -3.9594e-01,  1.9770e+00, -1.2419e+00, -1.4203e-01,  2.8831e+01,
          -2.1622e-01,  2.9604e+01, -1.6057e-01,  3.4465e+01, -9.6004e-02,
           3.3804e+01,  4.5614e-03,  3.1811e-02, -3.6468e-02,  7.7028e-03],
         [ 1.0133e-01,  1.0133e-01,  3.8128e+01,  3.8128e+01,  2.2079e-02,
          -6.3589e-02, -3.2411e-01,  2.6529e+00, -6.5885e-03,  1.4360e-01,
           1.2226e-01,  0.0000e+00, -1.0792e-01,  3.4128e-01, -8.5840e-02,
           2.7769e-01, -3.4136e-02, -1.0512e+00,  3.3123e-02,  2.4070e+01,
          -6.2227e-02,  2.4165e+01, -4.6598e-02,  2.9324e+01,  3.1937e-02,
           2.9247e+01,  8.6151e-03, -3.3164e-02, -8.1405e-02, -4.3179e-03],
         [ 2.6432e-01, -3.7339e-02,  1.1069e+01,  2.7205e+01, -4.0951e-03,
          -1.0512e-01, -4.5474e-01, -1.7421e+00,  1.4921e-01, -1.4148e-01,
          -1.6022e-01,  0.0000e+00, -3.7240e-02,  8.6861e-03, -4.1336e-02,
          -9.6437e-02, -4.6129e-01, -2.1930e-01,  1.2671e-01,  2.0225e+01,
           1.9799e-02,  2.0906e+01,  8.9292e-02,  2.5825e+01,  1.7677e-01,
           2.5276e+01,  1.8265e-03,  5.0008e-03,  1.3965e-02, -2.2795e-03]],
        [[-2.4649e-01, -5.5179e-01,  8.8924e+00,  2.3261e+01,  5.0892e-02,
           1.1412e-01,  1.3875e+00,  2.2890e+00, -1.6297e-01, -1.0621e-01,
          -9.4548e-02,  0.0000e+00,  4.0567e-03,  6.1696e-02,  5.4948e-02,
           1.7581e-01,  9.1497e-01,  1.5440e+00, -2.8891e-01,  2.8142e+01,
          -3.6700e-01,  2.8744e+01, -3.2110e-01,  3.3719e+01, -2.5405e-01,
           3.3207e+01, -2.4874e-03, -1.0706e-02, -9.8496e-02,  1.0876e-02],
         [-2.2715e-01, -2.2715e-01,  3.6628e+01,  3.6628e+01, -2.1661e-01,
          -4.4198e-01, -2.7478e+00,  8.0418e-02, -5.0902e-01, -1.8981e-01,
          -1.0664e-01,  0.0000e+00,  1.0298e-01,  9.4488e-02, -1.1363e-01,
          -3.4749e-01,  1.5066e+00,  3.9758e+00, -3.0099e-01,  2.9054e+01,
          -3.7839e-01,  2.9463e+01, -3.4607e-01,  3.4536e+01, -2.7987e-01,
           3.4188e+01,  1.5170e-02,  4.4323e-02, -9.8671e-02,  7.0515e-03],
         [ 1.8373e-01, -3.3716e-01,  4.2431e+00,  1.4035e+01, -1.8539e-02,
           3.1327e-01,  5.8566e-01,  1.1661e+00, -3.3689e-01,  1.0228e-01,
           6.3571e-02,  0.0000e+00,  4.0617e-02, -3.7686e-01,  2.2079e-02,
          -6.3589e-02, -3.2411e-01,  2.6529e+00,  5.8397e-02,  2.4133e+01,
          -3.6923e-02,  2.4155e+01, -2.6812e-02,  2.9325e+01,  5.1685e-02,
           2.9306e+01, -1.4387e-03,  1.1748e-02,  4.6042e-02,  3.9803e-03],
         [-1.7667e-01, -1.7667e-01,  2.7667e+01,  2.7667e+01, -2.1033e-02,
          -2.0172e-01,  2.5859e-01, -2.0971e+00, -2.4106e-01, -2.8707e-02,
          -2.7080e-02,  0.0000e+00,  1.6937e-02,  9.6600e-02, -4.0951e-03,
          -1.0512e-01, -4.5474e-01, -1.7421e+00, -2.4828e-01,  2.0666e+01,
          -3.5982e-01,  2.0500e+01, -3.6311e-01,  2.5676e+01, -2.7385e-01,
           2.5809e+01, -1.5866e-03,  9.1029e-03, -6.4818e-02, -1.3253e-02]],
        [[-1.6491e-01, -1.6491e-01,  3.7640e+01,  3.7640e+01, -3.1784e-02,
           3.2591e-01, -1.4907e+00,  1.2319e+00,  3.4916e-01, -2.2137e-02,
          -2.6655e-04,  0.0000e+00,  8.2676e-02, -2.1179e-01,  5.0892e-02,
           1.1412e-01,  1.3875e+00,  2.2890e+00, -2.5406e-01,  2.8060e+01,
          -3.3062e-01,  2.8807e+01, -2.7523e-01,  3.3688e+01, -2.0896e-01,
           3.3051e+01, -8.8114e-04,  6.7067e-04,  1.4919e-01, -2.4587e-02],
         [ 1.0186e-02, -2.8051e-01,  1.0741e+01,  2.6717e+01,  1.4444e-02,
           5.5334e-02,  7.8251e-01,  3.0108e+00, -1.0133e-01,  3.6721e-01,
           3.3543e-01,  0.0000e+00, -2.3106e-01, -4.9731e-01, -2.1661e-01,
          -4.4198e-01, -2.7478e+00,  8.0418e-02, -7.5118e-02,  2.8847e+01,
          -1.4950e-01,  2.9603e+01, -9.4973e-02,  3.4477e+01, -3.0311e-02,
           3.3831e+01,  1.1581e-02,  4.3092e-02,  5.1529e-03, -1.1217e-03],
         [ 1.1156e-01,  1.1156e-01,  3.8119e+01,  3.8119e+01, -7.5301e-02,
           2.0663e-01, -6.6783e-01,  1.7858e+00,  2.4426e-01,  1.1629e-01,
           1.0294e-01,  0.0000e+00,  5.6763e-02,  1.0665e-01, -1.8539e-02,
           3.1327e-01,  5.8566e-01,  1.1661e+00,  4.8594e-02,  2.4146e+01,
          -4.6770e-02,  2.4119e+01, -4.0435e-02,  2.9293e+01,  3.8092e-02,
           2.9315e+01,  4.2772e-03, -2.8184e-02, -1.2063e-01, -1.1147e-02],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
        [[-3.1301e-01, -6.1151e-01,  8.9435e+00,  2.3398e+01,  1.5990e-02,
           7.1948e-02, -8.5906e-01, -2.8210e+00, -1.1974e-02, -2.3011e-01,
          -2.0958e-01,  0.0000e+00, -4.7774e-02,  2.5396e-01, -3.1784e-02,
           3.2591e-01, -1.4907e+00,  1.2319e+00, -3.4843e-01,  2.8158e+01,
          -4.2567e-01,  2.8831e+01, -3.7520e-01,  3.3763e+01, -3.0863e-01,
           3.3190e+01, -1.5298e-03, -2.6828e-02, -6.5327e-02,  9.3303e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 2.4919e-01, -2.8348e-01,  4.2938e+00,  1.4080e+01, -6.8882e-02,
           1.6863e-01,  5.4997e-01, -3.6955e+00, -1.0929e-01,  6.7108e-03,
          -4.4881e-02,  0.0000e+00, -6.4197e-03,  3.7995e-02, -7.5301e-02,
           2.0663e-01, -6.6783e-01,  1.7858e+00,  1.0623e-01,  2.4171e+01,
           1.0990e-02,  2.4159e+01,  1.8453e-02,  2.9331e+01,  9.6896e-02,
           2.9341e+01,  3.9892e-03, -1.1034e-02,  2.4403e-02,  2.4460e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
        [[-2.2188e-01, -2.2188e-01,  3.7635e+01,  3.7635e+01, -1.4795e-02,
           2.5464e-01,  3.0827e-01, -2.5479e+00,  2.0137e-01, -1.4957e-01,
          -1.2263e-01,  0.0000e+00,  3.0785e-02, -1.8270e-01,  1.5990e-02,
           7.1948e-02, -8.5906e-01, -2.8210e+00, -3.1730e-01,  2.8066e+01,
          -3.9192e-01,  2.8945e+01, -3.2803e-01,  3.3721e+01, -2.6280e-01,
           3.2970e+01, -4.4306e-03, -1.3863e-02,  1.3256e-01, -2.9475e-02],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 8.2153e-02,  8.2153e-02,  3.8075e+01,  3.8075e+01, -1.8012e-01,
           5.1354e-01, -1.0474e+00,  3.4399e+00,  6.0403e-01,  1.7191e-01,
           1.8520e-01,  0.0000e+00,  1.1124e-01, -3.4491e-01, -6.8882e-02,
           1.6863e-01,  5.4997e-01, -3.6955e+00,  2.2358e-02,  2.4192e+01,
          -7.2885e-02,  2.4063e+01, -7.4350e-02,  2.9239e+01,  4.1287e-03,
           2.9345e+01,  6.4416e-03, -2.7741e-02, -1.0866e-01, -1.4951e-02],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]])

observation0 = 
tensor([[[-1.3764e-01, -1.3764e-01,  3.7657e+01,  3.7657e+01,  5.4948e-02,
           1.7581e-01,  9.1497e-01,  1.5440e+00,  2.1075e-01,  3.7513e-02,
           6.2444e-02,  0.0000e+00, -6.3372e-02, -3.9705e-01, -8.4234e-03,
          -2.2124e-01, -3.6421e-01,  3.5387e+00, -2.2035e-01,  2.8098e+01,
          -2.9856e-01,  2.8700e+01, -2.5260e-01,  3.3676e+01, -1.8546e-01,
           3.3164e+01, -1.4493e-03,  3.2164e-02,  1.8048e-01, -1.7941e-02],
         [-2.2715e-01, -2.2715e-01,  3.6628e+01,  3.6628e+01, -2.1661e-01,
          -4.4198e-01, -2.7478e+00,  8.0418e-02, -5.0902e-01, -1.8981e-01,
          -1.0664e-01,  0.0000e+00,  1.0298e-01,  9.4488e-02, -1.1363e-01,
          -3.4749e-01,  1.5066e+00,  3.9758e+00, -3.0099e-01,  2.9054e+01,
          -3.7839e-01,  2.9463e+01, -3.4607e-01,  3.4536e+01, -2.7987e-01,
           3.4188e+01,  1.5170e-02,  4.4323e-02, -9.8671e-02,  7.0515e-03],
         [ 1.0133e-01,  1.0133e-01,  3.8128e+01,  3.8128e+01,  2.2079e-02,
          -6.3589e-02, -3.2411e-01,  2.6529e+00, -6.5885e-03,  1.4360e-01,
           1.2226e-01,  0.0000e+00, -1.0792e-01,  3.4128e-01, -8.5840e-02,
           2.7769e-01, -3.4136e-02, -1.0512e+00,  3.3123e-02,  2.4070e+01,
          -6.2227e-02,  2.4165e+01, -4.6598e-02,  2.9324e+01,  3.1937e-02,
           2.9247e+01,  8.6151e-03, -3.3164e-02, -8.1405e-02, -4.3179e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
        [[-1.6491e-01, -1.6491e-01,  3.7640e+01,  3.7640e+01, -3.1784e-02,
           3.2591e-01, -1.4907e+00,  1.2319e+00,  3.4916e-01, -2.2137e-02,
          -2.6655e-04,  0.0000e+00,  8.2676e-02, -2.1179e-01,  5.0892e-02,
           1.1412e-01,  1.3875e+00,  2.2890e+00, -2.5406e-01,  2.8060e+01,
          -3.3062e-01,  2.8807e+01, -2.7523e-01,  3.3688e+01, -2.0896e-01,
           3.3051e+01, -8.8114e-04,  6.7067e-04,  1.4919e-01, -2.4587e-02],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 1.1156e-01,  1.1156e-01,  3.8119e+01,  3.8119e+01, -7.5301e-02,
           2.0663e-01, -6.6783e-01,  1.7858e+00,  2.4426e-01,  1.1629e-01,
           1.0294e-01,  0.0000e+00,  5.6763e-02,  1.0665e-01, -1.8539e-02,
           3.1327e-01,  5.8566e-01,  1.1661e+00,  4.8594e-02,  2.4146e+01,
          -4.6770e-02,  2.4119e+01, -4.0435e-02,  2.9293e+01,  3.8092e-02,
           2.9315e+01,  4.2772e-03, -2.8184e-02, -1.2063e-01, -1.1147e-02],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]])

observation1 = 
tensor([[[-2.4649e-01, -5.5179e-01,  8.8924e+00,  2.3261e+01,  5.0892e-02,
           1.1412e-01,  1.3875e+00,  2.2890e+00, -1.6297e-01, -1.0621e-01,
          -9.4548e-02,  0.0000e+00,  4.0567e-03,  6.1696e-02,  5.4948e-02,
           1.7581e-01,  9.1497e-01,  1.5440e+00, -2.8891e-01,  2.8142e+01,
          -3.6700e-01,  2.8744e+01, -3.2110e-01,  3.3719e+01, -2.5405e-01,
           3.3207e+01, -2.4874e-03, -1.0706e-02, -9.8496e-02,  1.0876e-02],
         [-6.5071e-02, -3.5040e-01,  1.0754e+01,  2.6788e+01, -1.1363e-01,
          -3.4749e-01,  1.5066e+00,  3.9758e+00,  2.7810e-01,  2.5307e-01,
           2.2713e-01,  0.0000e+00,  1.2607e-01, -4.8451e-02,  1.2437e-02,
          -3.9594e-01,  1.9770e+00, -1.2419e+00, -1.4203e-01,  2.8831e+01,
          -2.1622e-01,  2.9604e+01, -1.6057e-01,  3.4465e+01, -9.6004e-02,
           3.3804e+01,  4.5614e-03,  3.1811e-02, -3.6468e-02,  7.7028e-03],
         [ 1.8373e-01, -3.3716e-01,  4.2431e+00,  1.4035e+01, -1.8539e-02,
           3.1327e-01,  5.8566e-01,  1.1661e+00, -3.3689e-01,  1.0228e-01,
           6.3571e-02,  0.0000e+00,  4.0617e-02, -3.7686e-01,  2.2079e-02,
          -6.3589e-02, -3.2411e-01,  2.6529e+00,  5.8397e-02,  2.4133e+01,
          -3.6923e-02,  2.4155e+01, -2.6812e-02,  2.9325e+01,  5.1685e-02,
           2.9306e+01, -1.4387e-03,  1.1748e-02,  4.6042e-02,  3.9803e-03],
         [ 2.6432e-01, -3.7339e-02,  1.1069e+01,  2.7205e+01, -4.0951e-03,
          -1.0512e-01, -4.5474e-01, -1.7421e+00,  1.4921e-01, -1.4148e-01,
          -1.6022e-01,  0.0000e+00, -3.7240e-02,  8.6861e-03, -4.1336e-02,
          -9.6437e-02, -4.6129e-01, -2.1930e-01,  1.2671e-01,  2.0225e+01,
           1.9799e-02,  2.0906e+01,  8.9292e-02,  2.5825e+01,  1.7677e-01,
           2.5276e+01,  1.8265e-03,  5.0008e-03,  1.3965e-02, -2.2795e-03]],
        [[-3.1301e-01, -6.1151e-01,  8.9435e+00,  2.3398e+01,  1.5990e-02,
           7.1948e-02, -8.5906e-01, -2.8210e+00, -1.1974e-02, -2.3011e-01,
          -2.0958e-01,  0.0000e+00, -4.7774e-02,  2.5396e-01, -3.1784e-02,
           3.2591e-01, -1.4907e+00,  1.2319e+00, -3.4843e-01,  2.8158e+01,
          -4.2567e-01,  2.8831e+01, -3.7520e-01,  3.3763e+01, -3.0863e-01,
           3.3190e+01, -1.5298e-03, -2.6828e-02, -6.5327e-02,  9.3303e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 2.4919e-01, -2.8348e-01,  4.2938e+00,  1.4080e+01, -6.8882e-02,
           1.6863e-01,  5.4997e-01, -3.6955e+00, -1.0929e-01,  6.7108e-03,
          -4.4881e-02,  0.0000e+00, -6.4197e-03,  3.7995e-02, -7.5301e-02,
           2.0663e-01, -6.6783e-01,  1.7858e+00,  1.0623e-01,  2.4171e+01,
           1.0990e-02,  2.4159e+01,  1.8453e-02,  2.9331e+01,  9.6896e-02,
           2.9341e+01,  3.9892e-03, -1.1034e-02,  2.4403e-02,  2.4460e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]])

observation0 and observation1 are the observations of agent 0 and agent 1 in the env, and I put zeros for done steps. I think it is also possible to apply the done mask to the observation first and then work only with a mask containing 0 and 1 values.
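A sketch of applying the done mask to the observations first, assuming done steps are marked with -1 in b as in the example above. `masked_fill` zeroes every done step; note that the comparisons `b == 0` and `b == 1` are already False at done steps, so the zeroing only matters if the done rows themselves should read as zeros.

```python
import torch

observation = torch.randn(6, 4, 30)
b = torch.tensor([[ 1,  0,  1,  0],
                  [ 0,  1,  0,  1],
                  [ 1,  0,  1, -1],
                  [ 0, -1,  0, -1],
                  [ 1, -1,  1, -1],
                  [-1, -1,  0, -1]], dtype=torch.int32)

done = b == -1                                            # [T, B] bool
# zero the features of every done step (the mask broadcasts over the feature dim)
observation = observation.masked_fill(done.unsqueeze(-1), 0.0)

# these masks select only the live steps of each agent
mask0 = b == 0
mask1 = b == 1
```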

Here is a naive solution:

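import torch

# observation: [T, B, F], agent_index: [T, B] with values -1 / 0 / 1
t, b, f = observation.shape

# the agents alternate, so each one has at most t // 2 steps per column
observation0 = torch.zeros([t // 2, b, f])
observation1 = torch.zeros_like(observation0)

for c in range(b):  # loop over the B (parallel env) dim
    cnt0 = 0
    cnt1 = 0
    for r in range(1, t):  # starts at r = 1, so step 0 is skipped
        if agent_index[r, c] == 0:
            observation0[cnt0, c, :] = observation[r, c, :]
            cnt0 += 1
        elif agent_index[r, c] == 1:
            observation1[cnt1, c, :] = observation[r, c, :]
            cnt1 += 1
        # steps with agent_index == -1 (done) are not copied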
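A vectorized sketch of the loop above, assuming observation is [T, B, F] and agent_index is [T, B] with values -1/0/1. It replaces the per-column counters with a cumulative sum along the time dimension: for each selected step, `cumsum(mask) - 1` is the number of earlier selected steps in the same column, which is the value cnt0/cnt1 would have at that point in the loop. Advanced indexing then scatters all selected rows at once. `split_by_agent` and `split_agents` are hypothetical helper names, not part of PyTorch or rlpyt.

```python
import torch


def split_by_agent(observation, agent_index, agent, n_rows=None):
    """Pack the steps of one agent to the top of each column (a sketch).

    observation: [T, B, F]; agent_index: [T, B] with values -1 / 0 / 1.
    Returns [n_rows, B, F]: for every column, the rows where
    agent_index == agent, in time order, followed by zeros.
    """
    mask = agent_index == agent                        # [T, B] bool
    # destination row of each selected step = number of earlier selected
    # steps in the same column (plays the role of cnt0 / cnt1 in the loop)
    pos = mask.long().cumsum(dim=0) - 1                # [T, B]
    if n_rows is None:
        n_rows = int(mask.sum(dim=0).max())            # largest per-column count
    _, b, f = observation.shape
    out = observation.new_zeros((n_rows, b, f))
    rows, cols = mask.nonzero(as_tuple=True)           # coordinates of selected steps
    # each (pos, col) pair is unique, so the scatter has no write conflicts
    out[pos[rows, cols], cols] = observation[rows, cols]
    return out


def split_agents(observation, agent_index):
    """Mirror the loop above: skip step 0 and allocate t // 2 rows per agent."""
    t = observation.shape[0]
    obs, idx = observation[1:], agent_index[1:]
    return (split_by_agent(obs, idx, 0, t // 2),
            split_by_agent(obs, idx, 1, t // 2))
```

`observation0, observation1 = split_agents(observation, agent_index)` should reproduce the loop's output. Passing `n_rows=None` to `split_by_agent` instead sizes the output from the data (the largest per-column count), which avoids relying on t // 2 being large enough.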

Note that if I do not use the -1 value in the b tensor, the number of 1s and 0s in a column will not be equal.
In case it helps, I’m using rlpyt for the RL part.

Thank you

How did you calculate the number of rows of observation0 and observation1?
Currently they are set to int(t/2), but what if the agent_index contains only zeros in row0?
Would this create an index error, or is this use case not possible?

The agent index changes sequentially: 0, 1, 0, 1, 0, 1, …, so half of the steps belong to agent 0 and half to agent 1. Hence a.shape[0] should be an even number (I updated the question). The only case in which this sequence changes to something like 0, 1, 0, 0, 1, 0, … is when the environment is done; since I use -1 for done steps, this does not happen, and each column is an alternating sequence of 0 and 1 until the first -1. So the maximum number of rows for each agent is int(t/2). If there is no -1 in a column, half of its steps belong to agent 0 and half to agent 1; if there is a -1 in the column, I fill the remaining rows with a default value such as zero.
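The assumptions above (agents alternate until the first -1, and at most int(t/2) rows per agent) can be checked on a batch before allocating the outputs. A minimal sketch; `check_layout` is a hypothetical helper, and it assumes agent_index is [T, B] with -1 for done steps.

```python
import torch


def check_layout(agent_index):
    """Hypothetical helper: returns (alternates, n_rows).

    alternates: True if consecutive non-done steps in every column
                belong to different agents.
    n_rows:     largest number of steps any single agent has in any
                column, i.e. the smallest row count that fits both outputs.
    """
    prev, cur = agent_index[:-1], agent_index[1:]
    both_alive = (prev != -1) & (cur != -1)
    alternates = bool((prev != cur)[both_alive].all())
    counts0 = (agent_index == 0).sum(dim=0)  # per-column steps of agent 0
    counts1 = (agent_index == 1).sum(dim=0)  # per-column steps of agent 1
    n_rows = int(torch.maximum(counts0, counts1).max())
    return alternates, n_rows


b = torch.tensor([[ 1,  0,  1,  0],
                  [ 0,  1,  0,  1],
                  [ 1,  0,  1, -1],
                  [ 0, -1,  0, -1],
                  [ 1, -1,  1, -1],
                  [-1, -1,  0, -1]], dtype=torch.int32)
alternates, n_rows = check_layout(b)
assert n_rows <= b.shape[0] // 2, "int(t/2) rows would not be enough"
```

This counts every step, including step 0; to match the loop, which starts at r = 1, pass `agent_index[1:]` instead.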

It doesn’t matter whether the values within a row are the same or not. These are parallel environments, and the starting agent can be 0 or 1.