Hi all,
I have a `conv2d` layer with parameters of shape, say, `(x, y, z, w)`, and a binary mask of the same shape `(x, y, z, w)`. I want to set the `requires_grad` attribute of every single weight in this `conv2d` layer individually; to be more clear, I want to set `x*y*z*w` separate `requires_grad` values, one for each weight value inside this conv filter, according to the given binary mask: wherever the mask has a 1, `requires_grad = True`, and vice versa. Is there a neat way of doing this other than looping over the entire filter?
Thanks for your help!
– Megh
Hi Megh!
There is no way to do what you want, with or without using loops. The `requires_grad` property applies to the tensor as a whole, not to its elements individually.
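You can see this directly: indexing into a tensor produces a new, non-leaf tensor, so there is no per-element flag to set. A quick illustration (the tensor here is just made up for the example):

```python
import torch

w = torch.randn(2, 3, requires_grad=True)

# w[0] is a new, non-leaf tensor produced by indexing, so this fails:
w[0].requires_grad = False
# RuntimeError: you can only change requires_grad flags of leaf variables.
```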
If your use case for this is to "freeze" just certain elements of the `Conv2d` `weight`, that is, to only update some of the elements of `weight` during training, the safest way is to store a copy of those elements before the update and then restore them after the update. E.g., something like this:
```python
with torch.no_grad():
    saved_weights = conv_layer.weight.clone()

# run training loop
...

# inside of training loop
opt.zero_grad()
loss.backward()
opt.step()
# restore the "frozen" elements that opt.step() just modified
with torch.no_grad():
    conv_layer.weight[mask.logical_not()] = saved_weights[mask.logical_not()]
```
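For concreteness, here is a minimal, self-contained sketch of the same scheme; the layer shape, mask, optimizer, and dummy data below are all made up for illustration:

```python
import torch

torch.manual_seed(0)

# hypothetical layer, mask, and data, just to make the sketch runnable
conv_layer = torch.nn.Conv2d(1, 2, 3)            # weight shape: (2, 1, 3, 3)
mask = torch.randint(0, 2, (2, 1, 3, 3)).bool()  # 1 = trainable, 0 = "frozen"
opt = torch.optim.SGD(conv_layer.parameters(), lr=0.1)

with torch.no_grad():
    saved_weights = conv_layer.weight.clone()

inp = torch.randn(4, 1, 8, 8)    # dummy batch
loss = conv_layer(inp).sum()     # dummy loss
opt.zero_grad()
loss.backward()
opt.step()

# restore the "frozen" elements that the step modified
with torch.no_grad():
    conv_layer.weight[mask.logical_not()] = saved_weights[mask.logical_not()]

# the masked-out elements are unchanged after the update
print(torch.equal(conv_layer.weight[~mask], saved_weights[~mask]))  # True
```

Note that gradients are still computed for all of the elements; the restore simply undoes the update for the masked-out ones, so wherever `mask` is 0, the corresponding weight never changes.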
Best.
K. Frank