@ptrblck
Sure.
For example:
# Build a toy batch: random features X and one integer class label per row.
import torch

num_class = 10
weight = 100  # number of columns in X (second dimension passed to torch.randn)
batch = 3  # 3 rows, matching the example tensors printed in this post
X = torch.randn(batch, weight)
# Sample labels uniformly from [0, num_class); using num_class instead of a
# literal 10 keeps the label range in sync if the class count changes.
label = torch.randint(0, num_class, (batch,))
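(In the printed example below, batch = 3, so X has shape [3 x 100] and the labels are {9, 1, 1}. In `desired X`, each row keeps only the 10 columns label*10 … label*10+9 of X, and every other entry is set to zero.)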
X = tensor([[-6.0834e-01, -5.9079e-01, -3.4196e-01, 5.7168e-01, -4.3331e-01,
1.6516e+00, 4.6272e-01, 4.5185e-01, -1.4575e+00, 4.0765e-02,
3.1781e-01, -1.6579e+00, 1.7221e+00, 7.1746e-01, -5.3044e-01,
-1.0118e+00, -3.5190e-01, -1.9081e+00, 1.5027e-01, 1.1446e-01,
-1.4772e+00, -2.5868e-01, -1.4384e+00, 6.8575e-01, 2.4126e-01,
3.2693e-01, -4.2781e-01, 2.1950e-03, -1.3695e+00, 2.1803e+00,
6.7851e-01, -2.4332e-01, 4.2386e-02, -1.1963e+00, -1.7549e+00,
-4.3406e-01, 1.6647e+00, -1.2375e+00, 2.0899e+00, 2.0276e+00,
2.8668e-01, 3.6571e-01, -1.6306e-01, -4.6049e-01, -8.9992e-03,
-6.0769e-01, 1.3757e+00, -1.1240e+00, -1.6341e-01, 1.4133e+00,
-6.3187e-01, 2.1754e+00, 2.0319e-01, -2.8198e-02, 7.5469e-02,
-5.0488e-01, -2.0968e+00, -2.7886e-01, -8.6695e-01, -6.3191e-01,
9.1306e-01, 8.0160e-01, -8.4536e-01, -1.2476e-01, -4.7699e-01,
1.5153e+00, 1.2025e+00, -3.8749e-01, 5.8015e-01, -1.2572e+00,
7.3191e-01, -1.2494e-01, -1.3664e+00, 1.6239e+00, 2.4665e-03,
5.3352e-02, 4.3461e-01, -6.1652e-01, 1.6548e+00, 3.3952e-02,
-8.0151e-01, 2.1024e-02, -8.1717e-01, 3.8690e-01, 8.2205e-01,
1.7624e+00, 2.6072e-01, -5.7074e-01, 9.8895e-01, 4.2740e-01,
1.1893e+00, -4.9188e-02, -1.4423e+00, -7.4522e-01, 2.7951e-01,
-1.9912e-01, -1.2297e-02, -7.6552e-02, -1.7420e-01, -1.4726e+00],
[ 1.6538e+00, 2.7518e-01, 2.5307e-01, -5.1267e-01, 6.1062e-01,
7.4058e-01, -8.4256e-02, 1.4839e+00, 7.0765e-01, 1.0990e+00,
2.5285e+00, -5.6504e-01, -2.5689e-01, 4.5166e-01, -5.7540e-01,
-1.1508e-03, 5.6673e-01, -8.1504e-01, 1.2127e+00, 6.2682e-01,
-5.1741e-01, 2.1806e+00, 2.6361e-01, -1.5621e+00, 1.3641e-01,
-8.1526e-01, 4.4094e-01, 8.1348e-01, -9.4383e-01, -4.2741e-01,
-8.4335e-02, -2.7072e+00, -2.3655e-01, -7.3133e-01, 1.2045e+00,
-4.7432e-01, -8.1001e-01, -2.8357e-01, -4.3105e-01, -3.3333e-01,
-1.7669e-01, 6.2751e-01, -1.4288e+00, 1.1203e+00, -6.9312e-01,
1.0733e+00, 4.6843e-01, 8.8390e-01, 8.0696e-01, 4.8746e-01,
-1.4401e+00, -5.8271e-01, 1.1030e+00, -6.9418e-01, 1.1979e-01,
-3.8968e-01, -5.9038e-01, 6.4342e-01, -4.2759e-01, 5.5010e-01,
2.5181e+00, -1.7223e-01, 1.0016e+00, 1.5534e+00, -1.0256e+00,
-7.9901e-01, 1.0004e+00, -2.9892e-01, -1.5527e+00, 1.0578e+00,
-3.8758e-01, -1.0519e+00, 1.0527e+00, -5.7627e-02, -4.3340e-01,
4.2683e-02, -1.4423e+00, -4.0123e-02, -5.9378e-01, 3.1165e-01,
-4.8258e-01, 2.3515e+00, -8.7709e-01, -1.8835e+00, -4.4491e-02,
3.6607e-01, -1.6562e-01, 2.4096e+00, 4.4988e-01, -9.3653e-01,
-1.4464e-02, 3.0374e-01, -1.4728e+00, 3.9607e-01, -9.8894e-01,
1.7206e-01, -6.0426e-01, -9.6410e-01, -5.0377e-01, 7.9586e-01],
[-4.6847e-01, 6.3710e-02, -5.2126e-01, 1.2289e+00, 1.6010e+00,
-1.0884e-01, 3.0637e-01, 2.0336e-01, -8.9281e-01, -1.1062e+00,
-1.9814e+00, 7.3068e-01, -5.0471e-01, -1.1721e+00, -3.8609e-01,
1.0695e+00, 1.4897e-01, 1.7660e+00, 8.4988e-01, -1.7742e+00,
4.5779e-01, 1.2820e+00, 1.6286e+00, -3.5526e-01, 8.2296e-01,
1.2147e+00, -1.3172e+00, -1.4643e-01, -1.0830e+00, -1.3402e+00,
-1.1356e+00, 1.0653e+00, -1.0476e+00, 9.2002e-02, 1.0856e+00,
-1.1596e+00, 2.9322e-01, -1.2866e+00, 8.6806e-01, -1.5686e+00,
8.4743e-02, 1.8034e-01, 1.2475e+00, 6.7162e-01, -8.6842e-01,
-1.1382e-01, -3.7422e-01, 7.6654e-01, -3.7366e-01, -9.6299e-01,
9.0468e-01, -9.0498e-01, -3.2067e-02, -7.6294e-01, 1.1564e+00,
1.1394e-01, 1.2499e+00, -3.5187e-01, -2.2726e-01, -7.3394e-01,
-7.5910e-01, -9.5211e-01, -8.3890e-01, 5.3075e-02, -6.5009e-01,
7.3806e-01, -1.2851e+00, -1.0579e+00, 1.6116e+00, -3.8414e-01,
-4.4634e-01, -1.0332e+00, 1.3748e+00, -4.8612e-01, -1.1840e+00,
-9.6800e-01, -3.9215e-01, -2.0809e-01, -9.6478e-01, 5.9514e-01,
-2.0644e-01, -1.0553e+00, -4.0866e-01, -1.6201e+00, -2.1583e-01,
-1.5920e-01, 1.3905e+00, -1.4117e+00, 7.6843e-01, 8.0739e-02,
-8.3867e-01, -1.2902e-01, -3.0284e-01, 5.3228e-01, -4.8684e-01,
5.3436e-01, -1.2595e+00, -9.4177e-02, 1.0984e+00, -9.2556e-01]])
desired X = tensor([[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
1.1893e+00, -4.9188e-02, -1.4423e+00, -7.4522e-01, 2.7951e-01,
-1.9912e-01, -1.2297e-02, -7.6552e-02, -1.7420e-01, -1.4726e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
2.5285e+00, -5.6504e-01, -2.5689e-01, 4.5166e-01, -5.7540e-01,
-1.1508e-03, 5.6673e-01, -8.1504e-01, 1.2127e+00, 6.2682e-01,
-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00,
0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00,
-1.9814e+00, 7.3068e-01, -5.0471e-01, -1.1721e+00, -3.8609e-01,
1.0695e+00, 1.4897e-01, 1.7660e+00, 8.4988e-01, -1.7742e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00,
0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00,
-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00,
0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00,
-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00,
0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00]])