Masking a range of tensors

I am wondering how I can use torch.Tensor.scatter_() or any other built-in masking function for this one-hot masking task.

I have two tensors:

X = [batch, 100] and label = [batch]

num_classes = 10

so each label corresponds to 10 of the 100 values in ‘X’.
For instance, for X of shape [1 x 100]:

X = ([ 0.0468, -1.7434, -1.0217, -0.0724, -0.5169, -1.7318, -0.1207, -0.8377,
        -0.8055,  0.7438,  0.1139,  1.2162, -1.7950,  1.7416, -1.2031, -1.4833,
        -0.5454,  0.2466, -1.2303, -0.4257,  0.9873, -1.5905, -1.3950,  0.4013,
        -1.0523,  1.4450,  0.6574,  1.5239, -0.3503, -0.1114,  1.8192, -1.7425,
         0.4678,  0.4074,  1.7606, -1.0502,  0.0724,  0.1721,  0.1108,  0.4453,
         0.2278, -1.5352, -0.1232,  1.1052,  0.2496,  1.2898, -0.4167, -0.8211,
         0.2340, -0.3829, -0.1328,  0.1033,  2.8693, -0.8802, -0.0433,  0.5335,
         0.0662,  0.4250,  0.2353, -0.1590,  0.0865,  0.6519, -0.2242,  1.5300,
         1.7021, -0.9451,  0.5845, -0.7309,  0.7124,  0.6544, -1.4426, -0.1859,
        -1.5313, -1.5391, -0.2138, -1.0203,  0.6678,  1.3445, -1.3453,  0.5222,
         0.9510,  0.0969, -0.5437, -0.2727, -0.6090, -2.9624,  0.4578,  0.5257,
        -0.2866,  0.0818, -1.2454,  1.6511,  0.1634,  1.3720, -0.4222,  0.5347,
         0.3586, -0.3506,  2.6866,  0.5084])

label = [3]

I would like to apply a one-hot mask to the tensor ‘X’: “1” at indices 30 to 39 and “0” at all the remaining indices.

Could you explain a bit how your current example would work, i.e. how would the label tensor with the value 3 mask the entries in X at indices 30 to 39?

@ptrblck I would like to reserve an equal slot of values for each label. Since in this case there are 10 labels and 100 values, each label gets a slot of 10. For instance, if there were 500 values, each label's slot would be 50 values marked as 1, with the rest marked 0.
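If that is the idea, one way to build such a mask for a single sample is with Tensor.scatter_(), filling in the indices owned by the label. A minimal sketch, assuming slots are assigned contiguously as [label * 10, (label + 1) * 10):

```python
import torch

num_classes = 10
width = 100
slot = width // num_classes   # 10 indices reserved per label

label = torch.tensor([3])     # label 3 owns indices 30..39
mask = torch.zeros(1, width)

# Column indices of the slot owned by the label: [[30, 31, ..., 39]]
cols = label.unsqueeze(1) * slot + torch.arange(slot).unsqueeze(0)
mask.scatter_(1, cols, 1.0)   # write 1.0 into the slot, leave the rest 0

# X_masked = X * mask would then zero everything outside the slot
```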

Would the shape of label in this case be [1, 10], and if so, what does the [3] refer to?

@ptrblck I am sorry if I did not understand your question.
I chose [3] as a label out of the 10 labels.

Usually my labels will have the shape [batch, label], but yes, in this case label = [3] (shape -> [1, 10]).
The label is just an indicator for performing one-hot masking on ‘X’.

I’m sorry, but I still think I misunderstand the use case.
If label contains a single integer (3), its shape would be [1] or [1, 1]. Are you expanding it to [1, 10] or how are these 10 values defined?

Could you post a small example with input tensors and the desired output?
That would make it easier to write code for the desired operation.

@ptrblck
Sure,
For example -

num_classes = 10
width = 100
batch = 3

X = torch.randn(batch, width)
label = torch.randint(0, num_classes, (batch,))

(Here X has shape [3 x 100] and the sampled labels happened to be {9, 1, 1}.)

X = tensor([[-6.0834e-01, -5.9079e-01, -3.4196e-01,  5.7168e-01, -4.3331e-01,
          1.6516e+00,  4.6272e-01,  4.5185e-01, -1.4575e+00,  4.0765e-02,
          3.1781e-01, -1.6579e+00,  1.7221e+00,  7.1746e-01, -5.3044e-01,
         -1.0118e+00, -3.5190e-01, -1.9081e+00,  1.5027e-01,  1.1446e-01,
         -1.4772e+00, -2.5868e-01, -1.4384e+00,  6.8575e-01,  2.4126e-01,
          3.2693e-01, -4.2781e-01,  2.1950e-03, -1.3695e+00,  2.1803e+00,
          6.7851e-01, -2.4332e-01,  4.2386e-02, -1.1963e+00, -1.7549e+00,
         -4.3406e-01,  1.6647e+00, -1.2375e+00,  2.0899e+00,  2.0276e+00,
          2.8668e-01,  3.6571e-01, -1.6306e-01, -4.6049e-01, -8.9992e-03,
         -6.0769e-01,  1.3757e+00, -1.1240e+00, -1.6341e-01,  1.4133e+00,
         -6.3187e-01,  2.1754e+00,  2.0319e-01, -2.8198e-02,  7.5469e-02,
         -5.0488e-01, -2.0968e+00, -2.7886e-01, -8.6695e-01, -6.3191e-01,
          9.1306e-01,  8.0160e-01, -8.4536e-01, -1.2476e-01, -4.7699e-01,
          1.5153e+00,  1.2025e+00, -3.8749e-01,  5.8015e-01, -1.2572e+00,
          7.3191e-01, -1.2494e-01, -1.3664e+00,  1.6239e+00,  2.4665e-03,
          5.3352e-02,  4.3461e-01, -6.1652e-01,  1.6548e+00,  3.3952e-02,
         -8.0151e-01,  2.1024e-02, -8.1717e-01,  3.8690e-01,  8.2205e-01,
          1.7624e+00,  2.6072e-01, -5.7074e-01,  9.8895e-01,  4.2740e-01,
          1.1893e+00, -4.9188e-02, -1.4423e+00, -7.4522e-01,  2.7951e-01,
         -1.9912e-01, -1.2297e-02, -7.6552e-02, -1.7420e-01, -1.4726e+00],
        [ 1.6538e+00,  2.7518e-01,  2.5307e-01, -5.1267e-01,  6.1062e-01,
          7.4058e-01, -8.4256e-02,  1.4839e+00,  7.0765e-01,  1.0990e+00,
          2.5285e+00, -5.6504e-01, -2.5689e-01,  4.5166e-01, -5.7540e-01,
         -1.1508e-03,  5.6673e-01, -8.1504e-01,  1.2127e+00,  6.2682e-01,
         -5.1741e-01,  2.1806e+00,  2.6361e-01, -1.5621e+00,  1.3641e-01,
         -8.1526e-01,  4.4094e-01,  8.1348e-01, -9.4383e-01, -4.2741e-01,
         -8.4335e-02, -2.7072e+00, -2.3655e-01, -7.3133e-01,  1.2045e+00,
         -4.7432e-01, -8.1001e-01, -2.8357e-01, -4.3105e-01, -3.3333e-01,
         -1.7669e-01,  6.2751e-01, -1.4288e+00,  1.1203e+00, -6.9312e-01,
          1.0733e+00,  4.6843e-01,  8.8390e-01,  8.0696e-01,  4.8746e-01,
         -1.4401e+00, -5.8271e-01,  1.1030e+00, -6.9418e-01,  1.1979e-01,
         -3.8968e-01, -5.9038e-01,  6.4342e-01, -4.2759e-01,  5.5010e-01,
          2.5181e+00, -1.7223e-01,  1.0016e+00,  1.5534e+00, -1.0256e+00,
         -7.9901e-01,  1.0004e+00, -2.9892e-01, -1.5527e+00,  1.0578e+00,
         -3.8758e-01, -1.0519e+00,  1.0527e+00, -5.7627e-02, -4.3340e-01,
          4.2683e-02, -1.4423e+00, -4.0123e-02, -5.9378e-01,  3.1165e-01,
         -4.8258e-01,  2.3515e+00, -8.7709e-01, -1.8835e+00, -4.4491e-02,
          3.6607e-01, -1.6562e-01,  2.4096e+00,  4.4988e-01, -9.3653e-01,
         -1.4464e-02,  3.0374e-01, -1.4728e+00,  3.9607e-01, -9.8894e-01,
          1.7206e-01, -6.0426e-01, -9.6410e-01, -5.0377e-01,  7.9586e-01],
        [-4.6847e-01,  6.3710e-02, -5.2126e-01,  1.2289e+00,  1.6010e+00,
         -1.0884e-01,  3.0637e-01,  2.0336e-01, -8.9281e-01, -1.1062e+00,
         -1.9814e+00,  7.3068e-01, -5.0471e-01, -1.1721e+00, -3.8609e-01,
          1.0695e+00,  1.4897e-01,  1.7660e+00,  8.4988e-01, -1.7742e+00,
          4.5779e-01,  1.2820e+00,  1.6286e+00, -3.5526e-01,  8.2296e-01,
          1.2147e+00, -1.3172e+00, -1.4643e-01, -1.0830e+00, -1.3402e+00,
         -1.1356e+00,  1.0653e+00, -1.0476e+00,  9.2002e-02,  1.0856e+00,
         -1.1596e+00,  2.9322e-01, -1.2866e+00,  8.6806e-01, -1.5686e+00,
          8.4743e-02,  1.8034e-01,  1.2475e+00,  6.7162e-01, -8.6842e-01,
         -1.1382e-01, -3.7422e-01,  7.6654e-01, -3.7366e-01, -9.6299e-01,
          9.0468e-01, -9.0498e-01, -3.2067e-02, -7.6294e-01,  1.1564e+00,
          1.1394e-01,  1.2499e+00, -3.5187e-01, -2.2726e-01, -7.3394e-01,
         -7.5910e-01, -9.5211e-01, -8.3890e-01,  5.3075e-02, -6.5009e-01,
          7.3806e-01, -1.2851e+00, -1.0579e+00,  1.6116e+00, -3.8414e-01,
         -4.4634e-01, -1.0332e+00,  1.3748e+00, -4.8612e-01, -1.1840e+00,
         -9.6800e-01, -3.9215e-01, -2.0809e-01, -9.6478e-01,  5.9514e-01,
         -2.0644e-01, -1.0553e+00, -4.0866e-01, -1.6201e+00, -2.1583e-01,
         -1.5920e-01,  1.3905e+00, -1.4117e+00,  7.6843e-01,  8.0739e-02,
         -8.3867e-01, -1.2902e-01, -3.0284e-01,  5.3228e-01, -4.8684e-01,
          5.3436e-01, -1.2595e+00, -9.4177e-02,  1.0984e+00, -9.2556e-01]])

desired X = tensor([[-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
          0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
          0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
          0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.1893e+00, -4.9188e-02, -1.4423e+00, -7.4522e-01,  2.7951e-01,
         -1.9912e-01, -1.2297e-02, -7.6552e-02, -1.7420e-01, -1.4726e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
          0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          2.5285e+00, -5.6504e-01, -2.5689e-01,  4.5166e-01, -5.7540e-01,
         -1.1508e-03,  5.6673e-01, -8.1504e-01,  1.2127e+00,  6.2682e-01,
         -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
          0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
          0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
        [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,
         -1.9814e+00,  7.3068e-01, -5.0471e-01, -1.1721e+00, -3.8609e-01,
          1.0695e+00,  1.4897e-01,  1.7660e+00,  8.4988e-01, -1.7742e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
          0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
          0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,
          0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00]])

Sorry, I still don’t understand how X and label can be used to yield the desired output.
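Reading the desired output above, it looks like row i keeps the slot X[i, label[i]*10 : (label[i]+1)*10] and zeroes everything else. A minimal batched sketch of that interpretation (the labels here are hard-coded to {9, 1, 1} to match the desired output shown):

```python
import torch

num_classes = 10
width = 100
slot = width // num_classes          # 10 values reserved per class

X = torch.randn(3, width)
label = torch.tensor([9, 1, 1])

# Row-wise boolean mask: True on indices [label[i]*slot, (label[i]+1)*slot)
idx = torch.arange(width).unsqueeze(0)        # [1, width]
start = (label * slot).unsqueeze(1)           # [batch, 1]
mask = (idx >= start) & (idx < start + slot)  # [batch, width] via broadcasting

out = X * mask                       # keep each row's slot, zero the rest
```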