Neat way of temporarily disabling grads for a model?

I’m implementing a version of DDPG, and trying to calculate the policy loss.

# TODO: Should find grad for states, policy_actions but treat q_model params as constant
policy_loss = -self.q_model(states, policy_actions).mean()

I can do this by looping through q_model params, storing current values, setting requires_grad to False, calculating policy_loss, then restoring after.

orig_requires_grads = [p.requires_grad for p in self.q_model.parameters()]
for p in self.q_model.parameters(): p.requires_grad = False
policy_loss = -self.q_model(states, policy_actions).mean()
for p, rg in zip(self.q_model.parameters(), orig_requires_grads): p.requires_grad = rg

This feels clumsy. Is there a better way to temporarily disable a model like this? I looked for something like self.q_model.detach()(states, policy_actions).mean() but was surprised to find it doesn’t exist. Obviously I could make a helper.

I plan on moving this model over to TorchScript at some point, and I’m not sure whether this is going to work out there?

There are some old threads that indicate that manual looping is the best way, but I want to check in case any developments have happened here.


I don’t think there have been any updates. The for loop is simple and is the most efficient thing that can be done here.
A built-in would also be tricky to get right given your extra logic of preserving parameters that already don’t require gradients.

Note that you can add a method to your q_model module yourself to do that to make it a bit cleaner.
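
For future readers, a minimal sketch of what such a helper could look like, written here as a context manager rather than a method. The name frozen_params is hypothetical, and the two nn.Linear layers are toy stand-ins for the real policy and Q models (the toy Q model takes the concatenated states and actions, which is an assumption, not the original signature).

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def frozen_params(module: nn.Module):
    """Temporarily set requires_grad=False on every parameter of `module`.

    Hypothetical helper, not part of PyTorch.
    """
    params = list(module.parameters())
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)
    try:
        yield module
    finally:
        # Restore the original flags, including any that were already False.
        for p, flag in zip(params, saved):
            p.requires_grad_(flag)


# Toy stand-ins for the policy and Q models (shapes are made up).
torch.manual_seed(0)
policy = nn.Linear(3, 2)
q_model = nn.Linear(3 + 2, 1)
states = torch.randn(4, 3)

with frozen_params(q_model):
    policy_actions = policy(states)
    q_input = torch.cat([states, policy_actions], dim=1)
    policy_loss = -q_model(q_input).mean()

# The graph was recorded while the Q parameters were frozen, so backward
# only reaches the policy parameters.
policy_loss.backward()
```

After this runs, policy.weight.grad is populated while q_model.weight.grad is still None, and the Q parameters have their original requires_grad flags back.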

Thanks for letting me know. Yeah, what I’ve got for now is a with_grad param in the forward method of my q_model that does what I showed. I wasn’t sure if there was a way to prevent the grad propagating through the graph without setting state like that, but it seems like there isn’t so this will have to do. I would still prefer a declarative rather than procedural API here, but it’s obviously not important.

Well, if you don’t want any gradients flowing at all, you can use torch.no_grad(). But that means that anything upstream of that block (the inputs and whatever produced them) won’t get gradients either.
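
A small sketch of why torch.no_grad() is too coarse for this case, using toy nn.Linear stand-ins (hypothetical shapes) for the policy and Q models. Nothing computed inside the block is recorded, so the result is cut off from the policy as well.

```python
import torch
from torch import nn

torch.manual_seed(0)
policy = nn.Linear(3, 2)       # toy stand-in for the policy model
q_model = nn.Linear(3 + 2, 1)  # toy stand-in for the Q model
states = torch.randn(4, 3)

policy_actions = policy(states)  # recorded in the graph as usual

with torch.no_grad():
    # No operations in this block are recorded by autograd.
    q_input = torch.cat([states, policy_actions], dim=1)
    q_values = q_model(q_input)

# The output has no history, so there is no path back to the policy.
has_history = q_values.grad_fn is not None
```

Here policy_actions still has a grad_fn, but q_values does not, so calling backward on a loss built from q_values would raise an error instead of updating the policy.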

In this case, the arguments given to the model should have their grads updated but the model’s own parameters shouldn’t. None of the no_grad systems allow for that specificity do they?

To explain the context, in DDPG we have a Q model which learns predicted rewards for state/action pairs, and a policy model which learns to choose the policy for a state. The Q model is trained from observed rewards. The policy model is trained by feeding its output into the Q model and doing gradient ascent to tweak the policy in the direction of a larger predicted reward. These both happen side-by-side in the same training loop.

The tricky bit is feeding the policy actions into the Q model to do the gradient ascent without also training the Q model. The requires_grad method works, but doing it in a more declarative way would require some way to tell it to only calculate grads back through a specific tensor or something like that. Like with torch.no_grad(except = policy_actions):
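
One possible way to get closer to that declarative style without touching requires_grad is torch.autograd.grad, which returns gradients only for the tensors you name and does not accumulate into anything else. This is a sketch with toy stand-in models, not a drop-in for the real code. Autograd still has to backpropagate through the Q model's operations to reach the policy; it just doesn't write to the Q parameters' .grad. I have not checked how this behaves under TorchScript.

```python
import torch
from torch import nn

torch.manual_seed(0)
policy = nn.Linear(3, 2)       # toy stand-in for the policy model
q_model = nn.Linear(3 + 2, 1)  # toy stand-in for the Q model
states = torch.randn(4, 3)

policy_actions = policy(states)
q_input = torch.cat([states, policy_actions], dim=1)
policy_loss = -q_model(q_input).mean()

# Ask only for gradients with respect to the policy parameters.
policy_params = list(policy.parameters())
grads = torch.autograd.grad(policy_loss, policy_params)

# Store them where an optimizer would expect to find them.
for p, g in zip(policy_params, grads):
    p.grad = g
```

The Q model's parameters keep requires_grad=True throughout, and their .grad attributes are left untouched.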

Oops. It doesn’t particularly matter for the discussion, but there’s a little bug in the comment in my first post. It should only pass the grads back through policy_actions, not through states too. That’s a mistake that slipped in when simplifying the code to post on here, but apparently the forum won’t let me edit my first post. My real code does states.detach() there.

No, they don’t. I just wanted to mention it in case future readers are in a different scenario.


I am new to PyTorch.
Could anyone please clarify whether the code below pauses the computation graph at line 2 (‘for p in self.q_model.parameters(): p.requires_grad = False’) and resumes it after line 4 (‘for p, rg in zip(self.q_model.parameters(), orig_requires_grads): p.requires_grad = rg’)?

I am trying to implement a custom loss function where I need to pre-process the output of the network. This involves converting the tensor to NumPy and back to a tensor, and then feeding the result to a log loss function. Will using the code above during the pre-processing help compute gradients correctly after the pre-processing part?
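
A small sketch that may help with this, assuming a toy nn.Linear network and a made-up pre-processing step (clipping at zero). Toggling requires_grad on parameters only controls whether those parameters collect gradients; it does not pause or resume the graph. Separately, a round trip through NumPy produces a tensor with no autograd history, so gradients cannot flow back through it, whatever the parameter flags are.

```python
import numpy as np
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(3, 1)  # toy stand-in for the network
output = net(torch.randn(4, 3))

# Round trip through NumPy: detach() is needed before numpy(), and the
# tensor that comes back has no connection to `net`.
as_numpy = output.detach().numpy()
processed = torch.from_numpy(np.clip(as_numpy, 0.0, None))

# The same (made-up) pre-processing done with torch ops stays in the graph.
in_torch = output.clamp(min=0.0)
in_torch.sum().backward()
```

Here processed.grad_fn is None while in_torch.grad_fn is set, and net.weight.grad is populated only through the torch version. If the pre-processing can be written with torch operations, that is usually the simpler route.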