Calculating the gradient with respect to part of the input of a neural network module


I ran into an issue when calculating the gradient with respect to only one part of the input, and I don’t know how to specify which input the gradient should be taken for. The problem is as follows:

Basically, I would like to implement the reinforcement learning algorithm DDPG in PyTorch, and one step of the algorithm is to calculate the gradient of Q(S,A) with respect to A. The network takes (S, A) as input, and I need the gradient of Q with respect to A only, not S. The neural network for Q(S,A) is as follows:
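
To make the setup concrete, here is a minimal sketch, assuming a small MLP critic that concatenates S and A. The class name `Critic`, the layer sizes, and the dimensions are all hypothetical placeholders rather than the actual network. It also shows one way the gradient with respect to A alone might be obtained, using `torch.autograd.grad` with only the action tensor marked `requires_grad=True`:

```python
import torch
import torch.nn as nn


class Critic(nn.Module):
    """Hypothetical Q(S, A) network: concatenates state and action (assumed design)."""

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # Join S and A along the feature dimension before the first layer.
        return self.net(torch.cat([state, action], dim=1))


torch.manual_seed(0)
state_dim, action_dim, batch_size = 3, 2, 4  # assumed sizes, for illustration only
critic = Critic(state_dim, action_dim)

# Only the action tensor tracks gradients; the state is treated as a constant.
state = torch.randn(batch_size, state_dim)
action = torch.randn(batch_size, action_dim, requires_grad=True)

q = critic(state, action)  # shape: (batch_size, 1)

# Each Q value depends only on its own row of `action`, so differentiating
# the sum yields the per-sample gradient dQ/dA for every row at once.
(grad_a,) = torch.autograd.grad(q.sum(), action)

print(grad_a.shape)  # same shape as `action`
```

Because `torch.autograd.grad` returns the gradient directly instead of accumulating into `.grad`, the critic's parameters are left untouched by this call. In a DDPG actor update, a common alternative is to backpropagate `-critic(state, actor(state)).mean()` and step only the actor's optimizer, which applies the same dQ/dA implicitly through the chain rule.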

Are there any comments or suggestions on this issue? I would really appreciate it!