Help with indexing

Hi PyTorch community,

I am seeing unexpected behavior in the following scenario. tne is the batch size, and N is the total number of possible locations for each example. Each example has exactly 24 scalar outputs, whose locations are stored in the dof tensor (shape (tne, 24)). Fint_e has shape (tne, 24), i.e., the 24 outputs for each example. I am trying to construct a larger tensor of shape (tne, N), where row i holds Fint_e[i] at the columns listed in dof[i] and zeros everywhere else. When I do the following, the values end up in the wrong places:

Fint_MAT        = torch.zeros((tne,N))
Fint_MAT[:,dof]  = Fint_e

Here is a reproducible example that illustrates the issue:

import torch

tne = 3
N   = 48
Fint_MAT = torch.zeros((tne, N))
Fint_e   = torch.randn((tne, 24))
v1 = torch.arange(24).unsqueeze(0)       # row 0: columns 0..23
v2 = torch.arange(12, 36).unsqueeze(0)   # row 1: columns 12..35
v3 = torch.arange(24, 48).unsqueeze(0)   # row 2: columns 24..47
dof      = torch.cat((v1, v2, v3), dim=0).long()  # shape (tne, 24)
Fint_MAT[:, dof] = Fint_e

Each row should have 24 nonzeros and 24 zeros, with the nonzero positions differing from row to row; the columns of the nonzeros in each row are given by the corresponding row of dof. However, all 48 entries in every row end up nonzero.
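Why the original assignment fills every entry: in Fint_MAT[:, dof], the slice `:` keeps all tne rows while the (tne, 24) index tensor dof selects columns, so the indexed region has shape (tne, tne, 24) and pairs every row of Fint_MAT with every row of dof. Fint_e broadcasts over that region, so each row receives values at the union of all columns in dof (here all 48), and where rows of dof overlap, which value lands in a given column is not well defined. Pairing dof with a per-row index instead selects one column set per row. A minimal sketch, assuming the same sizes and 12-column offset as the example above (the names rows, wrong and right are only for illustration):

```python
import torch

tne, N = 3, 48
Fint_e = torch.randn(tne, 24)
# Row i covers columns 12*i .. 12*i + 23, matching v1, v2, v3 in the example
dof = torch.stack([torch.arange(24) + 12 * i for i in range(tne)])

# Original approach: ':' combined with a (tne, 24) index selects a (tne, tne, 24) region,
# so every row gets values at the columns listed in every row of dof
wrong = torch.zeros(tne, N)
print(wrong[:, dof].shape)  # torch.Size([3, 3, 24])
wrong[:, dof] = Fint_e      # Fint_e broadcasts from (3, 24) to (3, 3, 24)
print((wrong != 0).sum(dim=1))

# Row-indexed fix: rows has shape (tne, 1) and broadcasts against dof's (tne, 24),
# so index pair (i, j) addresses element (i, dof[i, j])
rows = torch.arange(tne).unsqueeze(1)
right = torch.zeros(tne, N)
right[rows, dof] = Fint_e
print((right != 0).sum(dim=1))
```

Both this row-indexed assignment and scatter_ place Fint_e[i, j] at column dof[i, j] of row i; scatter_ expresses the row pairing through its dim argument instead of an explicit row index.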

I’m not sure if I understand the use case correctly, but would scatter_ work?

Fint_MAT.scatter_(1, dof, Fint_e)
print(Fint_MAT)
> tensor([[-1.8981e-01,  1.4342e+00,  3.2527e-01, -4.6723e-01,  5.3798e-01,
          2.2025e-01,  7.3162e-01,  8.4189e-01,  3.5707e+00, -1.1000e+00,
         -1.3492e+00, -1.3385e-01,  4.7299e-01,  2.2391e+00, -1.1256e+00,
         -5.3990e-01,  3.0493e+00,  1.9767e+00,  3.1951e-01, -1.1720e+00,
         -7.3205e-01, -5.1244e-01,  1.3242e+00,  4.6621e-02,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  1.4342e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -5.6937e-01, -6.7827e-01, -1.1316e+00,
         -9.1802e-01, -9.4490e-01, -1.7924e+00,  9.1504e-01, -9.5629e-01,
          1.2545e+00, -1.2473e+00, -5.6272e-02,  1.7227e+00, -4.9540e-01,
          2.1001e-01,  2.2813e-01, -8.4102e-01,  5.3948e-01, -7.1214e-01,
         -9.2740e-01, -2.2006e-02,  5.2053e-01,  1.6844e-02, -6.8976e-01,
         -7.4224e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  3.2527e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -4.4625e-01,
          9.3915e-01,  7.7942e-01,  2.2305e-01,  9.9742e-01,  8.0741e-01,
          6.6438e-01, -1.0362e+00,  1.8041e-02,  2.5564e-01,  1.3967e+00,
         -3.5279e-01, -1.2420e+00,  5.4921e-01, -2.4048e-01, -1.9073e-05,
         -9.6158e-01, -7.9354e-01, -1.3437e+00, -1.7198e+00, -1.3244e+00,
         -3.9174e-01, -5.5179e-01,  4.9326e-01]])
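scatter_ fits this use case because, for a 2-D tensor and dim=1, self.scatter_(1, index, src) sets self[i][index[i][j]] = src[i][j] for every i and j (the rule stated in the PyTorch documentation for Tensor.scatter_). A small sketch checking that against an explicit double loop, assuming the same sizes as the example (scattered and looped are illustrative names; the loop only spells out the rule):

```python
import torch

tne, N = 3, 48
Fint_e = torch.randn(tne, 24)
# Same column layout as v1, v2, v3 in the example: row i covers 12*i .. 12*i + 23
dof = torch.stack([torch.arange(24) + 12 * i for i in range(tne)])

# Vectorized: for dim=1, scattered[i, dof[i, j]] = Fint_e[i, j]
scattered = torch.zeros(tne, N).scatter_(1, dof, Fint_e)

# Reference loop applying the same rule element by element
looped = torch.zeros(tne, N)
for i in range(tne):
    for j in range(dof.size(1)):
        looped[i, dof[i, j]] = Fint_e[i, j]

print(torch.equal(scattered, looped))  # True
```

The two agree here because no row of dof repeats a column; if a row contained duplicate indices, scatter_ would be non-deterministic about which value is kept.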

Thanks, @ptrblck, for the reply. This is along the lines of what I expect, but it is not quite behaving as desired. For this use case, I expect the 24 nonzero entries to be at the beginning of the first row (indices 0 to 23). In the second row, they should be in the middle (indices 12 to 35), and in the third row, at the end (indices 24 to 47).
Thanks again!

That is what scatter_ should produce; I guess Fint_MAT still held some intermediate results from an earlier run.
After rerunning I get these results:

tensor([[-0.8591,  0.8688,  1.2835,  0.9208,  0.1976,  1.0741, -1.4244,  0.9609,
          1.0544,  0.2955, -2.0284,  2.2574,  0.2363, -0.5887,  0.4460, -1.2576,
          0.2777, -0.0992, -1.1086, -0.0031, -0.8203,  0.4957, -0.8439, -0.8891,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000, -0.4912,  0.7667,  1.0390, -0.9783,
         -0.2817,  0.8187,  0.0521,  0.1803,  1.7951,  0.3633,  1.3415, -0.9563,
         -1.0455, -2.2024, -0.4149, -0.0370,  1.4790, -0.5650, -2.1901,  0.3896,
          0.3772,  0.1150,  1.3569,  1.5365,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.3428, -1.2165, -0.9348,  2.0201,  0.8961, -0.2702, -0.2217, -1.8465,
          0.0554, -0.3630,  2.3023,  1.0387, -0.0799, -0.9479, -0.9105,  0.4960,
         -0.7692, -0.6591, -1.1607, -2.1668, -0.1858, -0.6057,  0.5822,  0.5612]])

which look as you’ve described.
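One likely explanation for the extra nonzeros in the first output is that scatter_ works in place and only overwrites the positions listed in dof, so anything already stored in Fint_MAT (for example, values left by the earlier Fint_MAT[:, dof] = Fint_e attempt) survives. A sketch of that effect, assuming the same sizes as the example, together with the out-of-place Tensor.scatter applied to a fresh zero tensor, which cannot pick up stale values (clean is an illustrative name):

```python
import torch

tne, N = 3, 48
Fint_e = torch.randn(tne, 24)
dof = torch.stack([torch.arange(24) + 12 * i for i in range(tne)])

# Simulate leftover state: the broadcast assignment writes all 48 columns of every row
Fint_MAT = torch.zeros(tne, N)
Fint_MAT[:, dof] = Fint_e

# In-place scatter_ on the reused tensor overwrites only the dof positions,
# so the stale entries outside them remain nonzero
Fint_MAT.scatter_(1, dof, Fint_e)
print((Fint_MAT != 0).sum(dim=1))

# Out-of-place scatter on a newly allocated zero tensor leaves zeros elsewhere
clean = torch.zeros(tne, N).scatter(1, dof, Fint_e)
print((clean != 0).sum(dim=1))
```

Allocating the target with torch.zeros right before the scatter, as in the last line, keeps the result independent of anything computed earlier in the session.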


Many thanks, @ptrblck. This solves it.