Compute weight gradients for first forward pass and input gradients for the rest

Hi,

I’m trying to implement a more efficient iterative method similar to adversarial attacks.

For each model I want to compute the standard loss for the weight updates, but I also want to add a penalized loss before calling optimizer.step(). Computing the penalized loss requires iteratively computing gradients with respect to the inputs, but I don't need to store the weight gradients for any step of this iterative phase.

Is it possible to call loss.backward() after the first forward pass and then freeze the weight gradients, so that computing gradients with respect to the inputs doesn't modify them (i.e. torch.autograd.grad(loss, [x]) doesn't change model.arbitrary_layer.weight.grad)?

At the end of the iterative phase the loss function will be called again, but this time I want the resulting weight gradients to be accumulated with those from the first loss.

```python
logits, otheroutput = model(x, return_otheroutput=True)
loss = loss_fn(logits, x_label)
loss.backwards()

z = x.clone().detach().requires_grad_(True)
for i in range(iterations):
    z = z.clone().detach().requires_grad_(True)
    logits = model(z, return_otheroutput=False)
    loss = loss_fn(logits, other_label)
    grad = torch.autograd.grad(loss, [z])
    with torch.no_grad():
        z = compute_new_x(z, grad)

logits = model(z, return_otheroutput=False)
loss = loss_fn(z, other_label)
loss = -2*loss
loss.backwards()
optimizer.step()
```

The final loss should be <standard_loss> - 2 * <loss of adversarial image>. I call it an adversarial image, but it isn't really one, though it is quite similar. The -2 is a stand-in for the usual lambda scalar attached to penalized loss functions.
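Because `.backward()` adds into the existing `.grad` fields, calling it once on the standard loss and once on the scaled penalty gives the same weight gradient as a single call on the combined loss. Writing the modified input as $z$ (treated as a constant, since the loop detaches it, so no gradient flows back through the iterative updates), the labels as $y$ (`x_label`) and $y'$ (`other_label`), and the scalar as $\lambda = 2$:

```latex
\begin{aligned}
\mathcal{L}(\theta) &= \mathcal{L}_{\text{std}}(\theta; x, y) - \lambda\, \mathcal{L}_{\text{adv}}(\theta; z, y') \\
\nabla_\theta \mathcal{L} &= \nabla_\theta \mathcal{L}_{\text{std}} - \lambda\, \nabla_\theta \mathcal{L}_{\text{adv}}
\end{aligned}
```

The first term is what the first backward call stores in `.grad`; the second is what the final backward call on $-\lambda\,\mathcal{L}_{\text{adv}}$ adds to it.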

Hi,

You can use torch.autograd.grad to get the gradient with respect to a specific set of Tensors without touching any .grad field.
But you seem to be doing that properly in your sample.
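A quick check of that behaviour, as a minimal sketch with a toy `nn.Linear` model and random data (all hypothetical, not from the question): the weight gradient left by `.backward()` is unchanged after a `torch.autograd.grad` call with respect to the input, and the input's own `.grad` stays `None`, because `autograd.grad` returns the gradients instead of storing them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3)            # hypothetical toy model
x = torch.randn(2, 4)              # hypothetical batch
labels = torch.tensor([0, 2])
loss_fn = nn.CrossEntropyLoss()

# Normal backward pass: populates model.weight.grad and model.bias.grad.
loss_fn(model(x), labels).backward()
saved = model.weight.grad.clone()

# Gradient with respect to the input only; no .grad field is written.
z = x.clone().requires_grad_(True)
(grad_z,) = torch.autograd.grad(loss_fn(model(z), labels), [z])

print(torch.equal(model.weight.grad, saved))  # True
print(z.grad is None)                         # True
print(grad_z.shape)                           # torch.Size([2, 4])
```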

A few things:

  • you want .backward() (without the s)
  • The way you compute your loss does not match the formula you gave
  • Apart from that, this code should do what you want.
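Putting those fixes together, a minimal runnable sketch of the loop might look like the following. The toy model, data shapes, SGD optimizer, and the signed-gradient `compute_new_x` are hypothetical stand-ins, since the question does not define them; the structure follows the question's code, with `.backward()` spelled correctly and the penalty computed from the logits rather than from `z`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the objects used in the question.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
x = torch.randn(5, 4)
x_label = torch.randint(0, 3, (5,))
other_label = torch.randint(0, 3, (5,))
iterations = 3
lam = 2.0  # the "-2" in the question, i.e. lambda = 2

def compute_new_x(z, grad, step=0.1):
    # Hypothetical update rule: one signed-gradient step that lowers the loss for other_label.
    return z - step * grad.sign()

optimizer.zero_grad()

# 1) Standard loss: fills .grad for every parameter.
loss = loss_fn(model(x), x_label)
loss.backward()

# 2) Iterative phase: gradients with respect to z only; .grad fields are untouched.
z = x.clone().detach()
for _ in range(iterations):
    z = z.detach().requires_grad_(True)
    loss_z = loss_fn(model(z), other_label)
    (grad,) = torch.autograd.grad(loss_z, [z])
    with torch.no_grad():
        z = compute_new_x(z, grad)

# 3) Penalty term: computed from the logits and accumulated into .grad.
z = z.detach()
adv_loss = loss_fn(model(z), other_label)
(-lam * adv_loss).backward()

optimizer.step()
```

Since the second `.backward()` accumulates into `.grad`, the step uses the gradient of `loss - lam * adv_loss`, and the iterative phase contributes nothing to the weight gradients.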