Let me explain the problem in another way. I have a multi-agent environment with two agents. I have one tensor for agent_index, which contains only 0s and 1s, and one tensor for the done flag. The b tensor is calculated as follows:
# Encode "done" into the agent index: shift indices to {1, 2}, zero out done
# steps, then shift back so done steps become -1 and live steps keep {0, 1}.
agent_index = agent_index + 1 #to have 1, 2
agent_index = (1 - done) * agent_index # 0 or 1 or 2 -> 0 for done
agent_index = agent_index - 1  # live steps back to {0, 1}; done steps are now -1
The b tensor is equal to the agent_index tensor. I want to separate the data from the agents using the b tensor as a mask. Each column of the b tensor is a trajectory from one of many parallel envs (the first dimension is time).
one example of data:
b = agent_index =
tensor([[ 1, 0, 1, 0],
[ 0, 1, 0, 1],
[ 1, 0, 1, -1],
[ 0, -1, 0, -1],
[ 1, -1, 1, -1],
[-1, -1, 0, -1]], dtype=torch.int32)
a = observation =
tensor([[[-1.7150e-01, -4.8424e-01, 8.8617e+00, 2.3157e+01, -8.4234e-03,
-2.2124e-01, -3.6421e-01, 3.5387e+00, 1.5165e-01, 3.2926e-02,
4.2830e-02, 0.0000e+00, 9.0558e-03, 2.3515e-01, 6.3245e-04,
1.3912e-02, 1.1196e-01, 1.7685e+00, -2.2143e-01, 2.8138e+01,
-3.0058e-01, 2.8635e+01, -2.6145e-01, 3.3667e+01, -1.9383e-01,
3.3245e+01, 1.6753e-03, 8.9053e-03, -1.4090e-01, 8.9662e-03],
[-2.2966e-01, -2.2966e-01, 3.6587e+01, 3.6587e+01, 1.2437e-02,
-3.9594e-01, 1.9770e+00, -1.2419e+00, -4.0993e-01, -1.8288e-01,
-1.2287e-01, 0.0000e+00, -1.5647e-01, -1.8838e-02, -1.4403e-01,
-4.1478e-01, -3.0608e+00, -2.3535e+00, -3.0667e-01, 2.9029e+01,
-3.8337e-01, 2.9532e+01, -3.4508e-01, 3.4562e+01, -2.7926e-01,
3.4133e+01, 4.9681e-03, 3.0067e-02, 3.6508e-02, -3.0510e-03],
[ 1.0990e-01, -3.9639e-01, 4.1975e+00, 1.4005e+01, -8.5840e-02,
2.7769e-01, -3.4136e-02, -1.0512e+00, -2.7759e-01, 2.1368e-01,
1.8352e-01, 0.0000e+00, 7.7976e-02, -3.3801e-01, -7.8642e-03,
-6.0324e-02, -3.8388e-01, -2.0709e+00, 4.8057e-03, 2.4067e+01,
-9.0611e-02, 2.4146e+01, -7.6136e-02, 2.9309e+01, 2.4385e-03,
2.9243e+01, -1.7170e-04, 3.3451e-03, 6.6576e-02, 4.2948e-03],
[-1.7396e-01, -1.7396e-01, 2.7653e+01, 2.7653e+01, -4.1336e-02,
-9.6437e-02, -4.6129e-01, -2.1930e-01, -1.0034e-01, -1.9808e-02,
-2.3761e-02, 0.0000e+00, 1.3177e-01, -1.0492e-01, 9.0433e-02,
-2.0136e-01, -5.9576e-02, 3.6985e+00, -2.4864e-01, 2.0624e+01,
-3.6042e-01, 2.0524e+01, -3.5792e-01, 2.5700e+01, -2.6856e-01,
2.5780e+01, -5.9599e-03, 2.8747e-02, -1.1965e-01, -1.9982e-02]],
[[-1.3764e-01, -1.3764e-01, 3.7657e+01, 3.7657e+01, 5.4948e-02,
1.7581e-01, 9.1497e-01, 1.5440e+00, 2.1075e-01, 3.7513e-02,
6.2444e-02, 0.0000e+00, -6.3372e-02, -3.9705e-01, -8.4234e-03,
-2.2124e-01, -3.6421e-01, 3.5387e+00, -2.2035e-01, 2.8098e+01,
-2.9856e-01, 2.8700e+01, -2.5260e-01, 3.3676e+01, -1.8546e-01,
3.3164e+01, -1.4493e-03, 3.2164e-02, 1.8048e-01, -1.7941e-02],
[-6.5071e-02, -3.5040e-01, 1.0754e+01, 2.6788e+01, -1.1363e-01,
-3.4749e-01, 1.5066e+00, 3.9758e+00, 2.7810e-01, 2.5307e-01,
2.2713e-01, 0.0000e+00, 1.2607e-01, -4.8451e-02, 1.2437e-02,
-3.9594e-01, 1.9770e+00, -1.2419e+00, -1.4203e-01, 2.8831e+01,
-2.1622e-01, 2.9604e+01, -1.6057e-01, 3.4465e+01, -9.6004e-02,
3.3804e+01, 4.5614e-03, 3.1811e-02, -3.6468e-02, 7.7028e-03],
[ 1.0133e-01, 1.0133e-01, 3.8128e+01, 3.8128e+01, 2.2079e-02,
-6.3589e-02, -3.2411e-01, 2.6529e+00, -6.5885e-03, 1.4360e-01,
1.2226e-01, 0.0000e+00, -1.0792e-01, 3.4128e-01, -8.5840e-02,
2.7769e-01, -3.4136e-02, -1.0512e+00, 3.3123e-02, 2.4070e+01,
-6.2227e-02, 2.4165e+01, -4.6598e-02, 2.9324e+01, 3.1937e-02,
2.9247e+01, 8.6151e-03, -3.3164e-02, -8.1405e-02, -4.3179e-03],
[ 2.6432e-01, -3.7339e-02, 1.1069e+01, 2.7205e+01, -4.0951e-03,
-1.0512e-01, -4.5474e-01, -1.7421e+00, 1.4921e-01, -1.4148e-01,
-1.6022e-01, 0.0000e+00, -3.7240e-02, 8.6861e-03, -4.1336e-02,
-9.6437e-02, -4.6129e-01, -2.1930e-01, 1.2671e-01, 2.0225e+01,
1.9799e-02, 2.0906e+01, 8.9292e-02, 2.5825e+01, 1.7677e-01,
2.5276e+01, 1.8265e-03, 5.0008e-03, 1.3965e-02, -2.2795e-03]],
[[-2.4649e-01, -5.5179e-01, 8.8924e+00, 2.3261e+01, 5.0892e-02,
1.1412e-01, 1.3875e+00, 2.2890e+00, -1.6297e-01, -1.0621e-01,
-9.4548e-02, 0.0000e+00, 4.0567e-03, 6.1696e-02, 5.4948e-02,
1.7581e-01, 9.1497e-01, 1.5440e+00, -2.8891e-01, 2.8142e+01,
-3.6700e-01, 2.8744e+01, -3.2110e-01, 3.3719e+01, -2.5405e-01,
3.3207e+01, -2.4874e-03, -1.0706e-02, -9.8496e-02, 1.0876e-02],
[-2.2715e-01, -2.2715e-01, 3.6628e+01, 3.6628e+01, -2.1661e-01,
-4.4198e-01, -2.7478e+00, 8.0418e-02, -5.0902e-01, -1.8981e-01,
-1.0664e-01, 0.0000e+00, 1.0298e-01, 9.4488e-02, -1.1363e-01,
-3.4749e-01, 1.5066e+00, 3.9758e+00, -3.0099e-01, 2.9054e+01,
-3.7839e-01, 2.9463e+01, -3.4607e-01, 3.4536e+01, -2.7987e-01,
3.4188e+01, 1.5170e-02, 4.4323e-02, -9.8671e-02, 7.0515e-03],
[ 1.8373e-01, -3.3716e-01, 4.2431e+00, 1.4035e+01, -1.8539e-02,
3.1327e-01, 5.8566e-01, 1.1661e+00, -3.3689e-01, 1.0228e-01,
6.3571e-02, 0.0000e+00, 4.0617e-02, -3.7686e-01, 2.2079e-02,
-6.3589e-02, -3.2411e-01, 2.6529e+00, 5.8397e-02, 2.4133e+01,
-3.6923e-02, 2.4155e+01, -2.6812e-02, 2.9325e+01, 5.1685e-02,
2.9306e+01, -1.4387e-03, 1.1748e-02, 4.6042e-02, 3.9803e-03],
[-1.7667e-01, -1.7667e-01, 2.7667e+01, 2.7667e+01, -2.1033e-02,
-2.0172e-01, 2.5859e-01, -2.0971e+00, -2.4106e-01, -2.8707e-02,
-2.7080e-02, 0.0000e+00, 1.6937e-02, 9.6600e-02, -4.0951e-03,
-1.0512e-01, -4.5474e-01, -1.7421e+00, -2.4828e-01, 2.0666e+01,
-3.5982e-01, 2.0500e+01, -3.6311e-01, 2.5676e+01, -2.7385e-01,
2.5809e+01, -1.5866e-03, 9.1029e-03, -6.4818e-02, -1.3253e-02]],
[[-1.6491e-01, -1.6491e-01, 3.7640e+01, 3.7640e+01, -3.1784e-02,
3.2591e-01, -1.4907e+00, 1.2319e+00, 3.4916e-01, -2.2137e-02,
-2.6655e-04, 0.0000e+00, 8.2676e-02, -2.1179e-01, 5.0892e-02,
1.1412e-01, 1.3875e+00, 2.2890e+00, -2.5406e-01, 2.8060e+01,
-3.3062e-01, 2.8807e+01, -2.7523e-01, 3.3688e+01, -2.0896e-01,
3.3051e+01, -8.8114e-04, 6.7067e-04, 1.4919e-01, -2.4587e-02],
[ 1.0186e-02, -2.8051e-01, 1.0741e+01, 2.6717e+01, 1.4444e-02,
5.5334e-02, 7.8251e-01, 3.0108e+00, -1.0133e-01, 3.6721e-01,
3.3543e-01, 0.0000e+00, -2.3106e-01, -4.9731e-01, -2.1661e-01,
-4.4198e-01, -2.7478e+00, 8.0418e-02, -7.5118e-02, 2.8847e+01,
-1.4950e-01, 2.9603e+01, -9.4973e-02, 3.4477e+01, -3.0311e-02,
3.3831e+01, 1.1581e-02, 4.3092e-02, 5.1529e-03, -1.1217e-03],
[ 1.1156e-01, 1.1156e-01, 3.8119e+01, 3.8119e+01, -7.5301e-02,
2.0663e-01, -6.6783e-01, 1.7858e+00, 2.4426e-01, 1.1629e-01,
1.0294e-01, 0.0000e+00, 5.6763e-02, 1.0665e-01, -1.8539e-02,
3.1327e-01, 5.8566e-01, 1.1661e+00, 4.8594e-02, 2.4146e+01,
-4.6770e-02, 2.4119e+01, -4.0435e-02, 2.9293e+01, 3.8092e-02,
2.9315e+01, 4.2772e-03, -2.8184e-02, -1.2063e-01, -1.1147e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
[[-3.1301e-01, -6.1151e-01, 8.9435e+00, 2.3398e+01, 1.5990e-02,
7.1948e-02, -8.5906e-01, -2.8210e+00, -1.1974e-02, -2.3011e-01,
-2.0958e-01, 0.0000e+00, -4.7774e-02, 2.5396e-01, -3.1784e-02,
3.2591e-01, -1.4907e+00, 1.2319e+00, -3.4843e-01, 2.8158e+01,
-4.2567e-01, 2.8831e+01, -3.7520e-01, 3.3763e+01, -3.0863e-01,
3.3190e+01, -1.5298e-03, -2.6828e-02, -6.5327e-02, 9.3303e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 2.4919e-01, -2.8348e-01, 4.2938e+00, 1.4080e+01, -6.8882e-02,
1.6863e-01, 5.4997e-01, -3.6955e+00, -1.0929e-01, 6.7108e-03,
-4.4881e-02, 0.0000e+00, -6.4197e-03, 3.7995e-02, -7.5301e-02,
2.0663e-01, -6.6783e-01, 1.7858e+00, 1.0623e-01, 2.4171e+01,
1.0990e-02, 2.4159e+01, 1.8453e-02, 2.9331e+01, 9.6896e-02,
2.9341e+01, 3.9892e-03, -1.1034e-02, 2.4403e-02, 2.4460e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
[[-2.2188e-01, -2.2188e-01, 3.7635e+01, 3.7635e+01, -1.4795e-02,
2.5464e-01, 3.0827e-01, -2.5479e+00, 2.0137e-01, -1.4957e-01,
-1.2263e-01, 0.0000e+00, 3.0785e-02, -1.8270e-01, 1.5990e-02,
7.1948e-02, -8.5906e-01, -2.8210e+00, -3.1730e-01, 2.8066e+01,
-3.9192e-01, 2.8945e+01, -3.2803e-01, 3.3721e+01, -2.6280e-01,
3.2970e+01, -4.4306e-03, -1.3863e-02, 1.3256e-01, -2.9475e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 8.2153e-02, 8.2153e-02, 3.8075e+01, 3.8075e+01, -1.8012e-01,
5.1354e-01, -1.0474e+00, 3.4399e+00, 6.0403e-01, 1.7191e-01,
1.8520e-01, 0.0000e+00, 1.1124e-01, -3.4491e-01, -6.8882e-02,
1.6863e-01, 5.4997e-01, -3.6955e+00, 2.2358e-02, 2.4192e+01,
-7.2885e-02, 2.4063e+01, -7.4350e-02, 2.9239e+01, 4.1287e-03,
2.9345e+01, 6.4416e-03, -2.7741e-02, -1.0866e-01, -1.4951e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])
observation0 =
tensor([[[-1.3764e-01, -1.3764e-01, 3.7657e+01, 3.7657e+01, 5.4948e-02,
1.7581e-01, 9.1497e-01, 1.5440e+00, 2.1075e-01, 3.7513e-02,
6.2444e-02, 0.0000e+00, -6.3372e-02, -3.9705e-01, -8.4234e-03,
-2.2124e-01, -3.6421e-01, 3.5387e+00, -2.2035e-01, 2.8098e+01,
-2.9856e-01, 2.8700e+01, -2.5260e-01, 3.3676e+01, -1.8546e-01,
3.3164e+01, -1.4493e-03, 3.2164e-02, 1.8048e-01, -1.7941e-02],
[-2.2715e-01, -2.2715e-01, 3.6628e+01, 3.6628e+01, -2.1661e-01,
-4.4198e-01, -2.7478e+00, 8.0418e-02, -5.0902e-01, -1.8981e-01,
-1.0664e-01, 0.0000e+00, 1.0298e-01, 9.4488e-02, -1.1363e-01,
-3.4749e-01, 1.5066e+00, 3.9758e+00, -3.0099e-01, 2.9054e+01,
-3.7839e-01, 2.9463e+01, -3.4607e-01, 3.4536e+01, -2.7987e-01,
3.4188e+01, 1.5170e-02, 4.4323e-02, -9.8671e-02, 7.0515e-03],
[ 1.0133e-01, 1.0133e-01, 3.8128e+01, 3.8128e+01, 2.2079e-02,
-6.3589e-02, -3.2411e-01, 2.6529e+00, -6.5885e-03, 1.4360e-01,
1.2226e-01, 0.0000e+00, -1.0792e-01, 3.4128e-01, -8.5840e-02,
2.7769e-01, -3.4136e-02, -1.0512e+00, 3.3123e-02, 2.4070e+01,
-6.2227e-02, 2.4165e+01, -4.6598e-02, 2.9324e+01, 3.1937e-02,
2.9247e+01, 8.6151e-03, -3.3164e-02, -8.1405e-02, -4.3179e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
[[-1.6491e-01, -1.6491e-01, 3.7640e+01, 3.7640e+01, -3.1784e-02,
3.2591e-01, -1.4907e+00, 1.2319e+00, 3.4916e-01, -2.2137e-02,
-2.6655e-04, 0.0000e+00, 8.2676e-02, -2.1179e-01, 5.0892e-02,
1.1412e-01, 1.3875e+00, 2.2890e+00, -2.5406e-01, 2.8060e+01,
-3.3062e-01, 2.8807e+01, -2.7523e-01, 3.3688e+01, -2.0896e-01,
3.3051e+01, -8.8114e-04, 6.7067e-04, 1.4919e-01, -2.4587e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 1.1156e-01, 1.1156e-01, 3.8119e+01, 3.8119e+01, -7.5301e-02,
2.0663e-01, -6.6783e-01, 1.7858e+00, 2.4426e-01, 1.1629e-01,
1.0294e-01, 0.0000e+00, 5.6763e-02, 1.0665e-01, -1.8539e-02,
3.1327e-01, 5.8566e-01, 1.1661e+00, 4.8594e-02, 2.4146e+01,
-4.6770e-02, 2.4119e+01, -4.0435e-02, 2.9293e+01, 3.8092e-02,
2.9315e+01, 4.2772e-03, -2.8184e-02, -1.2063e-01, -1.1147e-02],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])
observation1 =
tensor([[[-2.4649e-01, -5.5179e-01, 8.8924e+00, 2.3261e+01, 5.0892e-02,
1.1412e-01, 1.3875e+00, 2.2890e+00, -1.6297e-01, -1.0621e-01,
-9.4548e-02, 0.0000e+00, 4.0567e-03, 6.1696e-02, 5.4948e-02,
1.7581e-01, 9.1497e-01, 1.5440e+00, -2.8891e-01, 2.8142e+01,
-3.6700e-01, 2.8744e+01, -3.2110e-01, 3.3719e+01, -2.5405e-01,
3.3207e+01, -2.4874e-03, -1.0706e-02, -9.8496e-02, 1.0876e-02],
[-6.5071e-02, -3.5040e-01, 1.0754e+01, 2.6788e+01, -1.1363e-01,
-3.4749e-01, 1.5066e+00, 3.9758e+00, 2.7810e-01, 2.5307e-01,
2.2713e-01, 0.0000e+00, 1.2607e-01, -4.8451e-02, 1.2437e-02,
-3.9594e-01, 1.9770e+00, -1.2419e+00, -1.4203e-01, 2.8831e+01,
-2.1622e-01, 2.9604e+01, -1.6057e-01, 3.4465e+01, -9.6004e-02,
3.3804e+01, 4.5614e-03, 3.1811e-02, -3.6468e-02, 7.7028e-03],
[ 1.8373e-01, -3.3716e-01, 4.2431e+00, 1.4035e+01, -1.8539e-02,
3.1327e-01, 5.8566e-01, 1.1661e+00, -3.3689e-01, 1.0228e-01,
6.3571e-02, 0.0000e+00, 4.0617e-02, -3.7686e-01, 2.2079e-02,
-6.3589e-02, -3.2411e-01, 2.6529e+00, 5.8397e-02, 2.4133e+01,
-3.6923e-02, 2.4155e+01, -2.6812e-02, 2.9325e+01, 5.1685e-02,
2.9306e+01, -1.4387e-03, 1.1748e-02, 4.6042e-02, 3.9803e-03],
[ 2.6432e-01, -3.7339e-02, 1.1069e+01, 2.7205e+01, -4.0951e-03,
-1.0512e-01, -4.5474e-01, -1.7421e+00, 1.4921e-01, -1.4148e-01,
-1.6022e-01, 0.0000e+00, -3.7240e-02, 8.6861e-03, -4.1336e-02,
-9.6437e-02, -4.6129e-01, -2.1930e-01, 1.2671e-01, 2.0225e+01,
1.9799e-02, 2.0906e+01, 8.9292e-02, 2.5825e+01, 1.7677e-01,
2.5276e+01, 1.8265e-03, 5.0008e-03, 1.3965e-02, -2.2795e-03]],
[[-3.1301e-01, -6.1151e-01, 8.9435e+00, 2.3398e+01, 1.5990e-02,
7.1948e-02, -8.5906e-01, -2.8210e+00, -1.1974e-02, -2.3011e-01,
-2.0958e-01, 0.0000e+00, -4.7774e-02, 2.5396e-01, -3.1784e-02,
3.2591e-01, -1.4907e+00, 1.2319e+00, -3.4843e-01, 2.8158e+01,
-4.2567e-01, 2.8831e+01, -3.7520e-01, 3.3763e+01, -3.0863e-01,
3.3190e+01, -1.5298e-03, -2.6828e-02, -6.5327e-02, 9.3303e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 2.4919e-01, -2.8348e-01, 4.2938e+00, 1.4080e+01, -6.8882e-02,
1.6863e-01, 5.4997e-01, -3.6955e+00, -1.0929e-01, 6.7108e-03,
-4.4881e-02, 0.0000e+00, -6.4197e-03, 3.7995e-02, -7.5301e-02,
2.0663e-01, -6.6783e-01, 1.7858e+00, 1.0623e-01, 2.4171e+01,
1.0990e-02, 2.4159e+01, 1.8453e-02, 2.9331e+01, 9.6896e-02,
2.9341e+01, 3.9892e-03, -1.1034e-02, 2.4403e-02, 2.4460e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])
The observations are for agent0 and agent1 in the env, and I put zeros for done steps. Alternatively, I think it would be possible to apply the done mask to the observation first and then work with a mask containing only 0 and 1 values.
Here is a naive solution:
# Split the interleaved two-agent rollout into one tensor per agent.
# observation: [T, B, F] (time, parallel envs, features);
# agent_index: [T, B] with 0 / 1 for the acting agent, -1 for done steps.
t, b, f = observation.shape
# Agents alternate timesteps, so each agent fills at most t // 2 rows per env;
# rows left unfilled (because of done / -1 entries) stay zero.
observation0 = torch.zeros([t // 2, b, f])
observation1 = torch.zeros_like(observation0)
for c in range(b):  # loop over the parallel-env (B) dimension
    cnt0 = 0  # next free row in observation0 for this env
    cnt1 = 0  # next free row in observation1 for this env
    # NOTE(review): starts at r = 1, skipping timestep 0 — presumably the
    # reset observation is excluded on purpose (the example output matches
    # this); confirm that is intended.
    for r in range(1, t):
        if agent_index[r, c] == 0:
            # was `samples.env.observation` in the original paste; use the
            # same `observation` tensor the shape was taken from, consistently
            observation0[cnt0, c, :] = observation[r, c, :]
            cnt0 += 1
        elif agent_index[r, c] == 1:
            observation1[cnt1, c, :] = observation[r, c, :]
            cnt1 += 1
        # agent_index == -1 (done): step is skipped entirely
One point is that if I did not use the -1 value in the b tensor, the number of 1s and 0s in a column would not necessarily be equal.
And if it helps, I’m using rlpyt for RL part.
Thank you