Set requires_grad according to a binary mask

Hi all,
I have a conv2d layer whose weight has shape, say, (x,y,z,w), and a binary mask of the same shape (x,y,z,w). I want to set the requires_grad attribute for every single weight value in this conv2d layer individually, i.e. x*y*z*w requires_grad values, one for every weight inside this conv filter, according to the binary mask: where the mask is 1, requires_grad = True, and where it is 0, requires_grad = False. Is there a neat way of doing this other than looping over the entire filter?
Thanks for your help!
– Megh

Hi Megh!

There is no way to do what you want – with or without using loops.

The requires_grad property applies to the entire tensor, not to elements
of the tensor individually.
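To see the distinction concretely, here is a small sketch (the layer sizes and the random input are arbitrary choices for illustration): `requires_grad` is a single Python bool on the weight tensor, while the gradient that `backward()` produces has one entry per weight element. Per-element "freezing" therefore has to act on the weight values (or their gradients), not on the flag.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)

# One flag for the whole (4, 3, 3, 3) weight tensor ...
print(conv.weight.shape)          # torch.Size([4, 3, 3, 3])
print(conv.weight.requires_grad)  # True

# ... but the gradient has one value per weight element.
out = conv(torch.randn(1, 3, 8, 8))
out.sum().backward()
print(conv.weight.grad.shape)     # torch.Size([4, 3, 3, 3])
```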

If your use case for this is to “freeze” just certain elements of the Conv2d
weight, that is, to update only some of the elements of weight during
training, the safest way is to store a copy of the weights before training
and then, after every optimizer step, copy the frozen elements back from
that copy. E.g., something like this:

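import torch

# assumes conv_layer, opt, and mask (same shape as conv_layer.weight,
# 1 = trainable, 0 = frozen) already exist; loader and loss_fn are
# hypothetical stand-ins for your own data loader and loss function
with torch.no_grad():
    saved_weights = conv_layer.weight.clone()

frozen = mask.logical_not()  # True where the weight must not change

# run training loop
for inputs, targets in loader:
    opt.zero_grad()
    loss = loss_fn(conv_layer(inputs), targets)
    loss.backward()
    opt.step()
    with torch.no_grad():
        # copy the frozen elements back, undoing whatever opt.step() did to them
        conv_layer.weight[frozen] = saved_weights[frozen]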
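A self-contained sketch of the store-and-restore idea follows. The layer sizes, the random stand-in data, the toy loss, the random mask, and the optimizer settings (SGD with momentum and weight decay) are all assumptions chosen for illustration. For comparison it also tries the tempting alternative of only zeroing the masked gradients with `Tensor.register_hook`; because SGD adds `weight_decay * weight` to the (zeroed) gradient, that variant lets the "frozen" elements drift, which illustrates one plausible reason why restoring the values is the safer option.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible random weights, mask and data


def train(conv, mask, mode, steps=5):
    """Train `conv` for a few steps; elements where mask == 0 are meant to stay fixed.

    mode="restore":   copy the frozen elements back after every opt.step()
    mode="grad_mask": only zero their gradients with a tensor hook
    """
    frozen = mask.logical_not()  # bool tensor, True where the weight should not change
    with torch.no_grad():
        saved = conv.weight.clone()
    if mode == "grad_mask":
        # multiply the incoming gradient by the mask before it lands in .grad
        conv.weight.register_hook(lambda g: g * mask)
    # momentum / weight_decay values are arbitrary illustration choices
    opt = torch.optim.SGD(conv.parameters(), lr=0.1, momentum=0.9, weight_decay=0.01)
    for _ in range(steps):
        x = torch.randn(2, 3, 8, 8)   # random stand-in for real data
        loss = conv(x).pow(2).mean()  # toy loss
        opt.zero_grad()
        loss.backward()
        opt.step()
        if mode == "restore":
            with torch.no_grad():
                conv.weight[frozen] = saved[frozen]
    return saved, frozen


drift, moved = {}, {}
for mode in ("restore", "grad_mask"):
    conv = nn.Conv2d(3, 4, kernel_size=3)
    mask = (torch.rand_like(conv.weight) > 0.5).float()  # hypothetical binary mask
    saved, frozen = train(conv, mask, mode)
    change = (conv.weight.detach() - saved).abs()
    drift[mode] = change[frozen].max().item()   # largest change among frozen elements
    moved[mode] = change[~frozen].max().item()  # largest change among trainable ones
    print(f"{mode}: frozen max change = {drift[mode]:.2e}, "
          f"trainable max change = {moved[mode]:.2e}")
```

In the restore run the frozen entries end up exactly equal to the saved copy, while the trainable ones move. In the grad_mask run the frozen entries shrink slightly because of the weight-decay term; with `weight_decay=0` and the hook in place from the first step, plain SGD (even with momentum) would leave them untouched as well.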

Best.

K. Frank