Require gradient only for some tensor elements (others should be held constant)

I would like to use a tensor in which only a few elements are variable, i.e. considered during the backpropagation step. Consider for example:

a = torch.ones((2, 2), requires_grad=True)
b = torch.ones((2, 2), requires_grad=True)
c = torch.ones((1, 2), requires_grad=True)

x = torch.ones((2, 1))

y = c @ (b @ (a @ x))

Now for the c tensor I want the c[0, 1] element to be constant, i.e. not being varied during the optimization procedure. One idea is that, after doing e.g. y.sum().backward(), I could zero the corresponding gradient elements:

c.grad[0, 1] = 0

Is this a correct way of dealing with the problem? Or are the other gradients (of a, b) influenced by the value of c.grad[0, 1] during the backpropagation step? Considering the gradient calculation, it seems they shouldn't be:

$$\frac{\partial L}{\partial w^i_{jk}} = \frac{\partial L}{\partial a^i_j} \cdot \frac{\partial a^i_j}{\partial w^i_{jk}}$$

where L = y.sum(), a^i_j is the value of the j-th element in the i-th layer of the computational graph (e.g. a^1_0 == (a @ x)[0]) and w^i_{jk} is the jk element of the i-th tensor (e.g. w^2_{01} == b[0, 1]). The first term of the above product is passed on during backpropagation, and since it doesn't contain derivatives with respect to any weights, it seems that modifying parts of the gradient of a specific tensor wouldn't affect previous gradients in the graph.

I just want to double check on this and ask in general if this is the preferred way of dealing with such situations? Maybe there is a more elegant (more appropriate) way of dealing with constant tensor elements?
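A quick way to check this is to compare the gradients of a and b before and after zeroing the entry in c.grad. The following is only a minimal sketch: the optimizer, the learning rate and the names a_grad_before / b_grad_before are assumptions added for illustration, not part of the original setup.

```python
import torch

a = torch.ones((2, 2), requires_grad=True)
b = torch.ones((2, 2), requires_grad=True)
c = torch.ones((1, 2), requires_grad=True)
x = torch.ones((2, 1))

# Assumption: plain SGD without momentum or weight decay.
opt = torch.optim.SGD([a, b, c], lr=0.1)

y = c @ (b @ (a @ x))
y.sum().backward()

# backward() has finished, so the gradients of a and b are already final.
a_grad_before = a.grad.clone()
b_grad_before = b.grad.clone()

c.grad[0, 1] = 0  # hold c[0, 1] constant for this step

opt.step()
print(c)  # c[0, 1] keeps its initial value, c[0, 0] is updated
```

Since the edit happens after backward() has returned, it cannot change a.grad or b.grad. Note that a zero gradient only means "no update" for optimizers whose step depends on the gradient alone; with weight decay, for example, the element would still change.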

Any help is appreciated. Thanks.


Hi,
Were you able to figure out a solution to this problem? I've been struggling with this issue in TensorFlow and have been using some ad-hoc methods to deal with it.


I ended up using a scalar tensor with requires_grad=True and then adding it to the overall constant tensor element like this:

x = torch.tensor(1., requires_grad=True)
y = torch.zeros((2, 2))
y[0, 1] += x

Now this causes all elements of y to be constant except for the (0, 1) element. If you want it the other way around, i.e. all but one element require grad, then you just need more tensors to add. For example, let's say the (0, 1) element should be constant, so we can do:

x = torch.ones((2, 1), requires_grad=True)
y = torch.tensor(1., requires_grad=True)
z = torch.zeros((2, 2))
z[:, 0:1] += x  # the slice 0:1 keeps the column shape (2, 1), matching x
z[1, 1] += y

For a 2D tensor (matrix) at most four distinct tensors with requires_grad=True should suffice to make all but a single element require grad as well.
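For arbitrary patterns, the same idea can be written with a single 1-D tensor of free values that is scattered into a constant tensor via a boolean mask. This is just a sketch of that variant: the helper assemble and the names const, free_mask, free_values and w are made up for this example.

```python
import torch

def assemble(const, free_mask, free_values):
    """Return a copy of `const` with the entries selected by `free_mask`
    replaced by the trainable `free_values` (in row-major order)."""
    out = const.clone()
    out[free_mask] = free_values
    return out

const = torch.tensor([[0.0, 7.0], [0.0, 0.0]])            # (0, 1) stays fixed at 7
free_mask = torch.tensor([[True, False], [True, True]])   # True marks trainable entries
free_values = torch.ones(3, requires_grad=True)           # one value per True entry

z = assemble(const, free_mask, free_values)

# Distinct weights make it visible which entry each gradient came from.
w = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
(z * w).sum().backward()
print(free_values.grad)  # weights at positions (0, 0), (1, 0), (1, 1)
```

The constant entry never receives a gradient because it is not connected to any tensor with requires_grad=True.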

Nevertheless I think that the method of setting the .grad attribute partially to zero works as well.
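Instead of editing .grad by hand after every backward call, the zeroing can also be attached to the tensor once with Tensor.register_hook. A minimal sketch, using the tensors from the first post; the name grad_mask is an assumption for this example.

```python
import torch

a = torch.ones((2, 2), requires_grad=True)
b = torch.ones((2, 2), requires_grad=True)
c = torch.ones((1, 2), requires_grad=True)
x = torch.ones((2, 1))

grad_mask = torch.tensor([[1.0, 0.0]])  # 0 marks the element to hold constant

# The hook receives the gradient of c and returns the version to store in c.grad.
handle = c.register_hook(lambda grad: grad * grad_mask)

y = c @ (b @ (a @ x))
y.sum().backward()
print(c.grad)  # the (0, 1) entry is zero

handle.remove()  # detach the hook when it is no longer needed
```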

Thanks for this answer. Thanks a lot. Could I bother you with one more question? Given that I'm going to use the opt.minimize method to minimize a loss function in TensorFlow, one needs to set a var_list argument, which tells the optimizer (SGD etc.) which tensors to vary. In that case, would I just add x (and y) to that list and leave z out, or can I just list z and let it figure out which elements are variables and which are constants?

I can only respond from the PyTorch perspective, but here you would make the original tensors (the ones with requires_grad=True) the parameters of the optimization. In the end, operations like y[0, 1] += x create a new node in the computation graph, with inputs x and y, where x is variable and y is constant. Only after that node does the resulting tensor y require grad as well, but it is not a leaf of the graph, and the autograd engine has no notion of only one element of a tensor being variable. I suspect the situation is the same for TensorFlow.
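In PyTorch terms this could look roughly as follows: only the leaf tensor goes into the optimizer, and the composite tensor is rebuilt from it in every iteration. This is a sketch under assumptions; the loss, the learning rate and the names const, free_mask, free_values and target are invented for the example.

```python
import torch

const = torch.tensor([[0.0, 7.0], [0.0, 0.0]])            # (0, 1) is the constant entry
free_mask = torch.tensor([[True, False], [True, True]])   # True marks trainable entries
free_values = torch.ones(3, requires_grad=True)           # leaf tensor

# Only the leaf tensor is handed to the optimizer, not the composite z.
opt = torch.optim.SGD([free_values], lr=0.1)
target = torch.zeros((2, 2))

for _ in range(5):
    opt.zero_grad()
    z = const.clone()           # rebuild the composite tensor each iteration
    z[free_mask] = free_values
    loss = ((z - target) ** 2).sum()
    loss.backward()
    opt.step()

final = const.clone()
final[free_mask] = free_values.detach()
print(final)  # trainable entries moved towards 0, entry (0, 1) is still 7
```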

hey - i think i have a similar situation and wanted to check my method of solving it -
say i have a matrix of weights in a neural network and i want only specific elements in the matrix to be trained using backprop, say the bottom row. so i would initialize the entire matrix as random with requires_grad=True and then multiply by a matrix of zeros with the bottom row set to ones with requires_grad=False.

will this work?

The easiest way is adding a mask constant.
Like

y = c@(b@(a@x))

You can add a constant like mask = torch.ones((1, 2)) (the same shape as c),
and set mask[0, 1] = 0,
and then switch your formula to

y = (mask * c) @ (b @ (a @ x))

so the gradient of y will not back propagate to c[0,1]


i didn't understand what you did with the @ but that's what i said, no?

i want to learn a representation only on top row of a matrix, called embed:

embed = torch.randn((1,10,10),requires_grad=True)
mask =torch.zeros((1,10,10),requires_grad=False)
mask[0,0,:] = 1
embed = embed * mask

  • now build the network with embed where i want it to be.

p.s
the reason i want it still to be a matrix is because of the architecture that can only accept a matrix at that point.

Yes, you got what I meant.
That's the easiest way to erase the gradient of your matrix.
I don't know what '@' means; I thought it might be any operator like \times \add \minus in your formula.

You can do that, but obviously it will change the results of the computation, if you set some of the elements of embed to zero already in the forward step. It depends on your use case, whether you want the masked elements to contribute to the forward pass or not. Using embed = embed * mask it will set the contributions to zero.

My original use case was different, I had non-zero matrix elements which should contribute to the final result, but not change during training. Hence I used a different approach.
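If the frozen elements should keep contributing to the forward pass, one variant is to send them through detach() instead of multiplying them by zero. This is not the approach described above, only a sketch of an alternative; the names mask and c_eff are assumptions for this example.

```python
import torch

a = torch.ones((2, 2), requires_grad=True)
b = torch.ones((2, 2), requires_grad=True)
c = torch.ones((1, 2), requires_grad=True)
x = torch.ones((2, 1))

mask = torch.tensor([[1.0, 0.0]])  # 1 = trainable, 0 = held constant

# Same values as c in the forward pass, but the masked-out entries
# come from a detached copy and therefore receive no gradient.
c_eff = mask * c + (1 - mask) * c.detach()

y = c_eff @ (b @ (a @ x))
y.sum().backward()
print(y, c.grad)
```

Here y has the same value as with the unmasked c, while c.grad is zero at position (0, 1).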

hey, i'm back with this issue of setting only certain elements of the tensor "learnable" -
my net is more complicated than this example but maybe it's enough to find my issue without describing the whole thing:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(2, 6, 3) #notice two channel input
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 6 * 6, 2)
        self.embed = torch.randn((1,10,10),requires_grad=True) # lets say the input is a 10x10 image, and so is the embedding before masking
        mask = torch.zeros((1,10,10),requires_grad=False)
        mask[0,:,5] = 1 # we want to mask the embedding such that only one row is actually updated during backprop
        self.embed = self.embed * mask # notice that embed requires_grad and mask doesn't require. is this how it should be done or vice versa?

    def forward(self, x):
        x = torch.cat([self.embed,x])
        out = self.conv1(x)
        out = self.flatten(out)
        return self.fc1(out)

should this work in theory?

in my case it's slightly more complicated, because certain inputs belong to one category and i want to learn an embedding of each category, so the forward pass accepts x and category, it looks like this:

self.embed =  self.embed.repeat(category_num,1,1)
def forward(self,x, category):
     x = torch.cat([self.embed[category], x])
     out = self.conv1(x)
     out = self.flatten(out)
     return self.fc1(out)

i'm doing almost exactly that and getting a weird error that i couldn't solve -

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

The reason you get this error is because you are using the self.embed = self.embed * mask version in forward (I will call this self.embed2 in the following, as if you did self.embed2 = self.embed * mask). This tensor object is however not a leaf in the computation graph; the leaf is the original self.embed = torch.randn((1,10,10),requires_grad=True) object (I will continue to call this self.embed in the following). So your computation graph goes like this: self.embed --(a)--> self.embed2 --(b)--> .... In forward you are effectively using the self.embed2 tensor, so the edge --(a)--> never gets renewed, and hence it will remain in the graph also for the second forward pass and hence the error that you attempt to backprop another time through (that part of) the graph. Also note that if this worked (you could make it work by using backward(retain_graph=True)), your updates to the original self.embed tensor would never be effective, since you're stuck with the self.embed2 version (which contains the original values of self.embed).
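The situation can be reproduced outside of a module with a few lines. This is a stripped-down sketch, not the original network; the loss and the name failed_steps are assumptions for this example.

```python
import torch

embed = torch.randn((1, 10, 10), requires_grad=True)  # leaf tensor
mask = torch.zeros((1, 10, 10))
mask[0, :, 5] = 1
embed2 = embed * mask  # computed once, outside the training loop

failed_steps = []
for step in range(2):
    # Every iteration reuses the same embed -> embed2 edge of the graph.
    loss = (embed2 ** 2).sum()
    try:
        loss.backward()
    except RuntimeError:
        # The saved tensors of that edge were freed by the first backward().
        failed_steps.append(step)

print(failed_steps)
```

The first backward pass goes through; the second one raises the error because the multiplication node was only ever created once.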

So the solution is to recompute self.embed2 on every forward pass. Basically you need to shift the relevant lines to forward:

def __init__(self):
    ...
    self.embed = torch.randn((1,10,10),requires_grad=True)
    # Move the `mask`ing part to `forward`.

def forward(self, x):
    mask = torch.zeros((1,10,10),requires_grad=False)
    mask[0,:,5] = 1
    embed = self.embed * mask  # don't overwrite `self.embed`
    x = torch.cat([embed,x])  # use the newly created `embed` tensor
    ...

That's a great explanation, thanks!

What should I do in case my masking tensor is actually a very large tensor with a binary pattern that I load in the argument section of the init function?

Can I somehow project a tensor I load when initiating the neural net onto forward pass?

Sure, actually you can instantiate the mask already in __init__, i.e.:

def __init__(self):
    ...
    self.embed = ...
    self.mask = ...
    self.mask[0,:,5] = 1
    # Don't compute the masked version of 'embed' yet.

def forward(self, x):
    embed = self.embed * self.mask  # don't overwrite `self.embed`
    x = torch.cat([embed,x])  # use the newly created `embed` tensor

However any operation that will create a new node in the computation graph (i.e. any operation that involves a tensor with requires_grad=True) needs to happen in forward, so that the changes to the parameters can become effective.
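Put together, a complete module along these lines might look like the sketch below. Several things here are assumptions and not taken from the posts above: the class name MaskedEmbedNet, the batched input shape (batch, 1, 10, 10), the use of nn.Parameter and register_buffer (so that the optimizer and .cuda() / .to() pick the tensors up automatically), and the size of the linear layer, which is derived from that input shape.

```python
import torch
import torch.nn as nn

class MaskedEmbedNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Leaf tensor, registered as a parameter so net.parameters() includes it.
        self.embed = nn.Parameter(torch.randn(1, 10, 10))
        mask = torch.zeros(1, 10, 10)
        mask[0, :, 5] = 1
        # Buffers are not trained but move with the module across devices.
        self.register_buffer("mask", mask)
        self.conv1 = nn.Conv2d(2, 6, 3)      # two input channels: embed + image
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(6 * 8 * 8, 2)   # 10x10 input -> 8x8 after the 3x3 conv

    def forward(self, x):
        # x is assumed to have shape (batch, 1, 10, 10).
        embed = self.embed * self.mask       # recomputed on every forward pass
        embed = embed.unsqueeze(0).expand(x.shape[0], -1, -1, -1)
        x = torch.cat([embed, x], dim=1)     # (batch, 2, 10, 10)
        out = self.conv1(x)
        return self.fc1(self.flatten(out))

net = MaskedEmbedNet()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
before = net.embed.detach().clone()

for _ in range(2):  # more than one step, to show that backward() can be repeated
    opt.zero_grad()
    net(torch.randn(4, 1, 10, 10)).sum().backward()
    opt.step()
```

With plain SGD, the entries of embed outside the masked column stay at their initial values, because their gradient is exactly zero.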

great!
another question, hope it's alright :)
i notice my gpu utilization is declining rapidly thanks to this embedding procedure, as it causes some bottleneck along the way. can you think of any method of reducing it?

my actual input data is of size (batch=4,channel=1,width=75,height=95,depth=65) -
this means my self.embed is of size (category_num=155,channel=1,width=75,height=95,depth=65)
and so is self.mask.

the init func is now as so:

        if args.subject_embedding:
            self.do_embed = True
            self.embed = torch.randn((self.n_subjects, 1, self.dim[0], self.dim[1], self.dim[2]), requires_grad=True)
            self.mask = torch.ones(self.embed.shape,requires_grad=False)
            if args.embedding_mask is not None:
                self.mask = args.embedding_mask
                self.mask.requires_grad = False
                if self.mask.shape[0] == 1:
                    self.mask = self.mask.unsqueeze(0).repeat(self.n_subjects, 1, 1, 1, 1)
                else:
                    self.mask = self.mask.unsqueeze(0).unsqueeze(0).repeat(self.n_subjects, 1, 1, 1, 1)
            if args.cuda:
                self.mask = self.mask.cuda()
                self.embed = self.embed.cuda()

the forward pass is implemented as so:

    def forward(self, x, subj):

        #  Level 1 context pathway
        #print('input shape ',x.shape)
        context_0 = x
        if self.do_embed:
            embed = self.embed * self.mask
            E = embed[subj.to(dtype=torch.long)]
            x = torch.cat([x,E],dim=1)

still, since the only nonzero elements of self.mask are 40 per category, i thought this would reduce the bottleneck. but now i see it doesn't and training will be slow.
notice that i push both the entire huge embedding to cuda and the huge ones tensor to cuda, where effectively i'm not sure that is necessary.
i was wondering if perhaps keeping them in cpu and only pushing to cuda the result of their multiplication (the tensor E ) in the forward pass, would reduce the bottleneck. i implemented it and couldn't prove to myself it's more efficient yet.
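One thing that might be worth trying (untested on the setup above, so only a sketch) is to select the rows for the current batch first and multiply by the mask afterwards, so the product is only computed for batch-many entries instead of all categories. Since the same mask is repeated for every subject, it can also be kept at a leading size of 1 and broadcast. The sizes below are small placeholders, and the variable names are assumptions for this example.

```python
import torch

n_subjects, dims = 6, (3, 4, 5)            # placeholder sizes, far smaller than the real data
embed = torch.randn(n_subjects, 1, *dims, requires_grad=True)

mask = torch.zeros(1, 1, *dims)            # one shared mask, broadcast over subjects
mask[0, 0, 0, 0, :2] = 1

subj = torch.tensor([0, 3])                # subject index of each sample in the batch
x = torch.randn(2, 1, *dims)

# Index first, then mask: the multiplication touches only the selected rows.
E = embed[subj] * mask
out = torch.cat([x, E], dim=1)             # (batch, 2, ...)
out.sum().backward()
```

The gradient with respect to embed is still a dense tensor of the full size, so this reduces the work in the forward multiplication but not the memory of the parameter itself.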

P.S
do not be confused by the initiation of self.mask as a ones tensor, that is for the case the masking feature is turned off and then it operates as an identity transformation.

Hi, I try some codes:

x = torch.tensor(1., requires_grad=True)
y = torch.zeros((2, 2))
y[0, 1] += x
print(y[0, 1].requires_grad)
print(y[1, 1].requires_grad)

It should print True and False. However, it outputs both True.
I am very confused about it. Hope you can help me to explain it.
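As far as I understand it, requires_grad is a property of the tensor as a whole and not of its individual elements, so after the in-place addition every slice of y reports True. Which elements actually depend on x only shows up in the gradient. A small sketch; the weight tensor w is an assumption added to make the dependence visible.

```python
import torch

x = torch.tensor(1., requires_grad=True)
y = torch.zeros((2, 2))
y[0, 1] += x

# The flag is set for the whole tensor, hence for every element view of it.
print(y.requires_grad, y[0, 1].requires_grad, y[1, 1].requires_grad)

# Weight each element differently; only the weight of (0, 1) reaches x.grad.
w = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
(y * w).sum().backward()
print(x.grad)
```

Only the weight at position (0, 1) appears in x.grad, which shows that the other elements are constant with respect to x even though the flag is True for them.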