Backward() with manipulation on weight tensor in training loop

I am trying to implement a training loop in which I only want to update a subset of the entries of a weight matrix. According to this discussion (or this one), it is possible to let the optimizer take its step and then manually reset some of the weights to their initial values.

Here is a toy example: I use Adam on the parameter x but want to keep its last column fixed. To do so, I save the initial values as xsave and, after each step, restore the last column with torch.where(...).

    import torch

    x = torch.nn.Parameter(torch.randn(2, 3))

    # initial values, and a mask that is True for the column to keep fixed
    xsave = x.clone().detach()
    mask_prevent_update = torch.full(x.shape, False)
    mask_prevent_update[:, 2] = True

    def loss_func():
        return torch.sum(x**2)
    print(loss_func())

    optimizer = torch.optim.Adam([x], lr=.1)

    for i in range(100):
        optimizer.zero_grad()
        loss = loss_func()
        loss.backward()
        optimizer.step()
        # reset the last column to its initial values
        x = torch.where(mask_prevent_update, xsave, x)

However, with this reset the loop fails with “RuntimeError: Trying to backward through the graph a second time”.

How can I avoid the error and still do the reset-operation?

Hi Lukas!

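The last line of your loop, x = torch.where(mask_prevent_update, xsave, x), sets the Python variable x to refer to the new tensor created by torch.where(). It doesn’t modify the Parameter to which x originally referred. Because loss_func() uses the global x, each new loss is computed from a tensor whose autograd graph reaches back into earlier iterations, and backward() eventually tries to go through a part of that graph that a previous backward() has already freed. Hence the error.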
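A quick way to see this, as a small sketch reusing the question's setup (param is an extra name introduced here only to keep a handle on the original Parameter):

```python
import torch

x = torch.nn.Parameter(torch.randn(2, 3))
param = x  # second reference to the original Parameter (illustration only)
xsave = x.clone().detach()
mask_prevent_update = torch.full(x.shape, False)
mask_prevent_update[:, 2] = True

# Rebinds the name x to a brand-new tensor; param is left untouched.
x = torch.where(mask_prevent_update, xsave, x)

print(x is param)             # False: x no longer refers to the Parameter
print(x.is_leaf)              # False: x is the output of an autograd op
print(x.grad_fn is not None)  # True: it carries a graph back to param
```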

Try:

    with torch.no_grad():
        # overwrite the Parameter's values in place, without recording autograd history
        x.copy_ (torch.where (mask_prevent_update, xsave, x))
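Putting it together, here is a sketch of the full toy loop with only the last line swapped for the in-place update above. The torch.manual_seed(0) call is an addition for repeatable runs, not part of the original code:

```python
import torch

torch.manual_seed(0)  # added only to make runs repeatable

x = torch.nn.Parameter(torch.randn(2, 3))

# initial values, and a mask that is True for the column to keep fixed
xsave = x.clone().detach()
mask_prevent_update = torch.full(x.shape, False)
mask_prevent_update[:, 2] = True

def loss_func():
    return torch.sum(x**2)

optimizer = torch.optim.Adam([x], lr=.1)

for i in range(100):
    optimizer.zero_grad()
    loss = loss_func()
    loss.backward()
    optimizer.step()
    # Restore the frozen column in place: x stays the same Parameter object
    # the optimizer holds, and no autograd graph is recorded for the copy.
    with torch.no_grad():
        x.copy_(torch.where(mask_prevent_update, xsave, x))

print(loss_func())
```

Because x is never rebound, every iteration builds a fresh graph starting from the same leaf Parameter, so backward() never revisits a freed graph. Adam still updates its running moment estimates for the frozen column; the in-place copy merely undoes the resulting change to those entries after each step.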

Best.

K. Frank