How am I supposed to pass a network module to a function?

I want to calculate the derivative with respect to the input vector and add noise to it.
So how am I supposed to pass a trained network module to a function?
And can I just do one forward pass, then call torch.autograd.backward to get the derivatives in input.grad?
Is there any more elegant way for this?

Hi,

So how am I supposed to pass a trained network module to a function?

I’m not sure I understand the question. You can pass your model to the function like any other argument.
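A minimal sketch of passing a model into a function that returns the gradient of a loss with respect to the input. The helper name input_gradient, the nn.Linear stand-in for the trained network and the MSE loss are all hypothetical choices for illustration; adding noise is left to the caller.

```python
import torch
import torch.nn as nn


def input_gradient(model: nn.Module, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return dLoss/dx. The model is just an ordinary argument."""
    # Work on a detached copy so the caller's tensor is left untouched.
    x = x.detach().clone().requires_grad_(True)
    loss = nn.functional.mse_loss(model(x), target)
    # torch.autograd.grad returns a tuple with one entry per differentiated tensor.
    (grad,) = torch.autograd.grad(loss, x)
    return grad


model = nn.Linear(4, 1)  # stands in for the trained network
x = torch.randn(8, 4)
target = torch.randn(8, 1)
g = input_gradient(model, x, target)
print(g.shape)  # torch.Size([8, 4])
```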

Is there any more elegant way for this?

You can use torch.autograd.grad() to get the gradients wrt the input directly. (And avoid populating the .grad fields of all the weights in your model).
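To make the difference concrete, here is a small sketch (the nn.Linear model is only a stand-in) comparing torch.autograd.grad with backward(): the first returns the input gradient without touching the weights' .grad fields, while the second accumulates gradients into every leaf tensor that requires grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x = torch.randn(2, 4, requires_grad=True)

# Option 1: torch.autograd.grad returns the gradient and leaves .grad fields alone.
loss = model(x).sum()
(grad_x,) = torch.autograd.grad(loss, x)
print(model.weight.grad)  # None: the weights were not touched

# Option 2: backward() accumulates into .grad of every leaf that requires grad.
loss = model(x).sum()  # fresh forward pass, because the first graph was freed
loss.backward()
print(model.weight.grad is not None)    # True
print(torch.allclose(x.grad, grad_x))   # True: same input gradient either way
```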

Thanks for your reply.

You can use torch.autograd.grad() to get the gradients wrt the input directly. (And avoid populating the .grad fields of all the weights in your model).

I checked the usage of torch.autograd.grad(), but I’m still confused.
Suppose there is an input vector input and a loss L. I want to calculate the gradient of L with respect to the input only, rather than with respect to the whole network. Can I write input_grad = torch.autograd.grad(L, input)?
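That call works, with two details worth knowing: the tensor you differentiate with respect to must have requires_grad=True, and torch.autograd.grad returns a tuple (one entry per input), so input_grad needs to be unpacked. A short sketch, using input_vec in place of input (to avoid shadowing the Python builtin) and a stand-in nn.Linear model:

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
# The tensor we differentiate with respect to must require grad,
# otherwise torch.autograd.grad raises a RuntimeError.
input_vec = torch.randn(5, 3, requires_grad=True)
L = model(input_vec).pow(2).mean()

result = torch.autograd.grad(L, input_vec)
print(type(result).__name__, len(result))  # tuple 1
input_grad = result[0]  # the gradient tensor itself, same shape as input_vec
```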

After getting the gradient with
input_grad = torch.autograd.grad(L, input)
I got
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
when I tried to update the network with
optimizer_d.zero_grad()
d_loss.backward()
optimizer_d.step()

Why is this happening, and how am I supposed to fix it? Thank you.

To my understanding, input_grad = torch.autograd.grad(L, input) calculates the gradients with respect to input, while d_loss.backward() and optimizer_d.step() compute the weights' gradients and update the weights with them. Is there some counter tracking how many times you differentiate through the network, and doing it more than once is not allowed in PyTorch?

In my case, if I don’t want to update the trained network and only want the gradients with respect to the input, can I use torch.autograd.grad without the problem above?
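For that case, a single torch.autograd.grad call differentiates through the graph only once, so retain_graph is not needed. A sketch under the assumption that the trained network is frozen by setting requires_grad_(False) on its parameters (the small nn.Sequential model stands in for it); gradients still flow back to the input:

```python
import torch
import torch.nn as nn

# Stands in for the trained network.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
model.eval()  # puts dropout / batch-norm layers (if any) in inference mode
for p in model.parameters():
    p.requires_grad_(False)  # freeze: autograd will not compute weight gradients

x = torch.randn(3, 4, requires_grad=True)
loss = model(x).sum()
# One differentiation pass over this graph, so retain_graph is not needed.
(x_grad,) = torch.autograd.grad(loss, x)
```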

Hi,

This error happens because, to reduce memory usage, we free all the temporary buffers saved for the backward pass as soon as a backward pass is done (because in most cases you only do one).
If you plan on doing more than one (using backward or grad), all but the last one should be called with retain_graph=True.
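Applied to the code from the question, the fix could look like the sketch below. optimizer_d, d_loss and L come from the question; the model model_d, the SGD optimizer and both loss definitions are hypothetical stand-ins. Both losses share the graph built by the forward pass, so the first pass over it keeps the buffers with retain_graph=True and the last one (d_loss.backward()) frees them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model_d = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))  # hypothetical network
optimizer_d = torch.optim.SGD(model_d.parameters(), lr=0.1)

x = torch.randn(16, 4, requires_grad=True)
out = model_d(x)
L = out.mean()               # loss used for the input gradient (stand-in)
d_loss = out.pow(2).mean()   # loss used to update the network; shares the graph with L

# First pass over the shared graph: keep the saved buffers for the next pass.
(input_grad,) = torch.autograd.grad(L, x, retain_graph=True)

# Last pass: no retain_graph, so the buffers are freed afterwards.
optimizer_d.zero_grad()
d_loss.backward()
optimizer_d.step()
```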