Selective gradient in large array

I have an array that has a few ones and a lot of zeros:
To be precise, it is the adjacency matrix of the pixels of an n*n image (here n=4), so it is n^2*n^2.
There are 4*n*(n-1) non-zero values in total (48 for n=4), but only 2*n*(n-1) independent values (24 for n=4), because the matrix is symmetric and each diagonal holds n*(n-1) values, corresponding to either the vertical or the horizontal adjacencies.

0  1  0  0  1  0  0  0  0  0  0  0  0  0  0  0
1  0  1  0  0  1  0  0  0  0  0  0  0  0  0  0
0  1  0  1  0  0  1  0  0  0  0  0  0  0  0  0
0  0  1  0  0  0  0  1  0  0  0  0  0  0  0  0
1  0  0  0  0  1  0  0  1  0  0  0  0  0  0  0
0  1  0  0  1  0  1  0  0  1  0  0  0  0  0  0
0  0  1  0  0  1  0  1  0  0  1  0  0  0  0  0
0  0  0  1  0  0  1  0  0  0  0  1  0  0  0  0
0  0  0  0  1  0  0  0  0  1  0  0  1  0  0  0
0  0  0  0  0  1  0  0  1  0  1  0  0  1  0  0
0  0  0  0  0  0  1  0  0  1  0  1  0  0  1  0
0  0  0  0  0  0  0  1  0  0  1  0  0  0  0  1
0  0  0  0  0  0  0  0  1  0  0  0  0  1  0  0
0  0  0  0  0  0  0  0  0  1  0  0  1  0  1  0
0  0  0  0  0  0  0  0  0  0  1  0  0  1  0  1
0  0  0  0  0  0  0  0  0  0  0  1  0  0  1  0 

I use it to compute things, and I end up with a final scalar loss.
I want to compute the gradient of the loss w.r.t. each element that has a 1 in the matrix, in order to optimize those elements.
I have thought of two ways to do that, but one of them is not doable (AFAIK), and I'm afraid the second one is inefficient.

First method:

  • Declare a tensor of size 2n(n-1) as a Variable with requires_grad=True
  • Insert those values into an n^2*n^2 tensor (this is the part I don't know how to do)
  • Do the computation and get the loss

With this method, I would only get the gradients of the values I want, and no others, so no useless computation.
The problem is, I don't know whether it is possible to insert the values into a new tensor and keep the gradient connection. Is that possible?
The closest I have found is torch.index_select(), but it can only select along one dimension, which would make the process rather awkward, since what I want to select are the diagonals.
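
To make it concrete, here is a minimal sketch of what I mean, with toy coordinates (whether the autograd graph is kept through the assignment is exactly what I would like to confirm):

import torch

n = 4
N = n * n
rows = torch.tensor([0, 0, 1, 1])               # toy coordinates of a few 1-entries, for illustration only
cols = torch.tensor([1, 4, 0, 2])
w = torch.rand(len(rows), requires_grad=True)   # the free parameters (2n(n-1) of them in my case)
A = torch.zeros(N, N)
A[rows, cols] = w                               # place the small tensor into the big one
loss = A.sum()                                  # placeholder for the actual computation
loss.backward()
print(w.grad)                                   # tensor([1., 1., 1., 1.]) if the graph survives the assignment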

Second method:

  • Declare a tensor of size n^2*n^2 as a Variable with requires_grad=True
  • Do the computation and get the loss

With this method, it would be rather inefficient to compute the gradient w.r.t. all the zero elements as well.
So is there a way to do a “selective requires_grad”, for example with a mask, that would avoid computing the gradients I don't need (especially useful as n increases)?
I saw that I can pass the mask (which is actually also the matrix above) directly to backward (loss.backward(mask)), which results in a gradient with non-zero values everywhere I want and zeros everywhere else, but is it done efficiently? In other words, when an element of the tensor passed to backward() is zero, does PyTorch still bother to compute the gradient for that element, given that it is going to end up being zero anyway?
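
For completeness, the only workaround I can picture so far is to mask the gradient itself with a hook on the dense leaf. A sketch, with a random stand-in for the mask and for my computation (note that this does not skip any computation, it only zeroes the unwanted gradients after the fact):

n = 4
mask = (torch.rand(n*n, n*n) > 0.5).float()    # stand-in for the 0/1 matrix above
A = torch.rand(n*n, n*n, requires_grad=True)   # dense leaf
A.register_hook(lambda grad: grad * mask)      # zero the gradient of the entries I don't want to train
loss = (A * A).sum()                           # stand-in for the actual computation
loss.backward()
print((A.grad[mask == 0] == 0).all())          # masked-out entries end up with a zero gradient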

Finally, I have seen the torch.sparse API in the docs, and I think I could use it, but I am unsure whether I can define a sparse tensor as a leaf variable, and whether the gradient propagates correctly through to_dense().
But mostly, the doc says “This API is currently experimental and may change in the near future.”, so I don't really want to use it if it's going to change…
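
If I end up testing it despite that warning, I suppose the experiment would look roughly like this (a sketch only; I have not verified that a gradient actually comes back to the sparse leaf, and the behavior may well depend on the version):

n = 4
i = torch.tensor([[0, 1], [1, 0]])   # 2 x nnz indices: entries (0, 1) and (1, 0)
v = torch.tensor([1.0, 1.0])         # their values
S = torch.sparse_coo_tensor(i, v, (n*n, n*n), requires_grad=True)
loss = S.to_dense().sum()
loss.backward()
print(S.grad)                        # does a gradient show up for the sparse leaf?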

I'd appreciate some help on these matters!

Hi, Matthieu Heitz.

Did you solve this problem? I'm also looking for a solution to a similar problem.

Can we do selective gradient calculation in the backward() function before updating?

Could you please let me know if you found the answer?

Best

Hi!

No, I haven't solved this. I am still using the big matrix as the leaf, and I hope that the gradient is not actually computed for the zero values.
Right now, I am more interested in whether the sparse API is going to be extended with more functions and autograd support. That would probably solve the problem.
I could try to implement the C functions myself (like a sparse matmul and its gradient) and wrap them in Python, as described here: https://github.com/pytorch/extension-ffi
I don’t know how long that is going to take me though…

I have finally succeeded in doing step 2 of the first method. It wasn't that hard, but it is a solution that only works in my case, because I only fill diagonals of the matrix.
To insert the values of an array of size 2*n*(n-1) into the matrix, I do it this way:

>>> n = 4
>>> W = torch.rand(2, n*(n-1), requires_grad=True)
>>> W
tensor([[ 0.4963,  0.7682,  0.0885,  0.1320,  0.3074,  0.6341,  0.4901, 0.8964,  0.4556,  0.6323,  0.3489,  0.4017],
        [ 0.0223,  0.1689,  0.2939,  0.5185,  0.6977,  0.8000,  0.1610, 0.2823,  0.6816,  0.9152,  0.3971,  0.8742]])
>>> D = W
>>> D1 = torch.nn.ConstantPad1d((0, 1), 0)(D[0].view(n, n - 1)).view(-1)[:-1]
>>> Dn = D[1]
>>> data = (Dn, D1, D1, Dn)
>>> offsets = torch.LongTensor([-n, -1, 1, n])
>>> M = [torch.diag(datum, offset) for datum, offset in zip(data, offsets)]
>>> A = torch.stack(M).sum(0)
# Calling .backward() on, for example, Z = torch.sum(A) works!
>>> A
tensor([[ 0.0000,  0.4963,  0.0000,  0.0000,  0.0223,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.4963,  0.0000,  0.7682,  0.0000,  0.0000,  0.1689,  0.0000, 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.0000,  0.7682,  0.0000,  0.0885,  0.0000,  0.0000,  0.2939, 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0885,  0.0000,  0.0000,  0.0000,  0.0000, 0.5185,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.0223,  0.0000,  0.0000,  0.0000,  0.0000,  0.1320,  0.0000, 0.0000,  0.6977,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.0000,  0.1689,  0.0000,  0.0000,  0.1320,  0.0000,  0.3074, 0.0000,  0.0000,  0.8000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.2939,  0.0000,  0.0000,  0.3074,  0.0000, 0.6341,  0.0000,  0.0000,  0.1610,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.5185,  0.0000,  0.0000,  0.6341, 0.0000,  0.0000,  0.0000,  0.0000,  0.2823,  0.0000,  0.0000, 0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.6977,  0.0000,  0.0000, 0.0000,  0.0000,  0.4901,  0.0000,  0.0000,  0.6816,  0.0000, 0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.8000,  0.0000, 0.0000,  0.4901,  0.0000,  0.8964,  0.0000,  0.0000,  0.9152, 0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.1610, 0.0000,  0.0000,  0.8964,  0.0000,  0.4556,  0.0000,  0.0000, 0.3971,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.2823,  0.0000,  0.0000,  0.4556,  0.0000,  0.0000,  0.0000, 0.0000,  0.8742],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.6816,  0.0000,  0.0000,  0.0000,  0.0000,  0.6323, 0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.9152,  0.0000,  0.0000,  0.6323,  0.0000, 0.3489,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000,  0.3971,  0.0000,  0.0000,  0.3489, 0.0000,  0.4017],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000,  0.0000,  0.8742,  0.0000,  0.0000, 0.4017,  0.0000]])

Further explanations:

# Pad one zero after every n-1 values, then remove the last padded zero.
>>> D1 = torch.nn.ConstantPad1d((0, 1), 0)(D[0].view(n, n-1)).view(-1)[:-1]
# D[1] already has the right number of values to fill the nth diagonal of an N*N matrix (N = n**2)
>>> Dn = D[1]
# List of diagonal data
>>> data = (Dn, D1, D1, Dn)
# List of offsets
>>> offsets = torch.LongTensor([-n, -1, 1, n])
# Creates a list of diagonal matrices for each diagonal in data.
>>> M = [torch.diag(datum, offset) for datum, offset in zip(data, offsets)]
# Sum all these matrices to get the final matrix
>>> A = torch.stack(M).sum(0)
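
And for what it's worth, here is the sanity check behind the comment in the code above: each value of W ends up in exactly two symmetric positions of A, so the gradient of Z = torch.sum(A) w.r.t. every entry of W should be 2.

>>> Z = torch.sum(A)
>>> Z.backward()
>>> W.grad
tensor([[ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.],
        [ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.]])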

I see a lot of functions in PyTorch for indexing into big arrays to get smaller ones (torch.index_select(), torch.masked_select(), torch.take(), torch.where()), but very few (e.g. torch.diag()) that do the opposite, i.e. place the values of a small array into a bigger one.

Does anyone know the reason for that?
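
That said, masked_scatter() does seem to go in that small-to-big direction. A minimal sketch, assuming the 0/1 matrix above is stored in a tensor called adjacency (note that it fills all 4*n*(n-1) positions independently, so keeping the matrix symmetric with only 2*n*(n-1) free values would still require placing each value twice):

n = 4
mask = adjacency.bool()                              # adjacency: the n^2 x n^2 matrix of 0s and 1s above
w = torch.rand(int(mask.sum()), requires_grad=True)  # one free value per non-zero entry (48 here)
A = torch.zeros(n*n, n*n).masked_scatter(mask, w)    # w is written into the 1-positions, in row-major order
A.sum().backward()
print(w.grad)                                        # all ones: the gradient reaches w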

Thanks.