I am trying to implement a training loop in which I want to update only a subset of the entries of a weight matrix. According to this discussion, or this, it is possible to let the optimizer take its step and then manually reset some of the weights to their initial values.
A toy example is this: I optimize the parameter x
with Adam, but want to keep its last column fixed. Therefore, I save the initial values as xsave
and restore the masked entries with torch.where(...)
.
import torch

x = torch.nn.Parameter(torch.randn(2, 3))
xsave = x.clone().detach()
mask_prevent_update = torch.full(x.shape, False)
mask_prevent_update[:, 2] = True

def loss_func():
    return torch.sum(x**2)

print(loss_func())
optimizer = torch.optim.Adam([x], lr=.1)
for i in range(100):
    optimizer.zero_grad()
    loss = loss_func()
    loss.backward()
    optimizer.step()
    x = torch.where(mask_prevent_update, xsave, x)
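To show the masking behavior I am relying on, here is a minimal standalone check of torch.where(mask, a, b), which takes entries from a where the mask is True and from b everywhere else (the tensor names a and b here are just illustrative):

```python
import torch

# torch.where(mask, a, b): take a where mask is True, b otherwise
mask = torch.full((2, 3), False)
mask[:, 2] = True                  # mask the last column, as above
a = torch.zeros(2, 3)              # stands in for the saved values
b = torch.ones(2, 3)               # stands in for the updated values
out = torch.where(mask, a, b)
print(out)
# last column comes from a, the rest from b
```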
However, the reassignment x = torch.where(...) inside the loop causes the error "RuntimeError: Trying to backward through the graph a second time" on the second pass through the loop.
How can I avoid this error while still resetting the masked entries after each optimizer step?