The input to my network is of shape [batch_size, 1], and the output is of shape [batch_size, 3].
That is, the network takes a single scalar as input and outputs a 3D vector.
My goal is to get the derivative of each element of the output w.r.t. the input. I expect the result to have shape [batch_size, 3] as well.
Currently, this is how I do it:
def d_network_d_input(network_obj, input_to_network):
    """
    :param network_obj: network object that returns output of shape [batch_size, 3] (x, y, z)
    :param input_to_network: shape [batch_size, 1] (t); must have requires_grad=True
    :return: dxdt, dydt, dzdt (each with the shape of the input) and the network output
    """
    network_obj.zero_grad()
    output_from_network = network_obj(input_to_network)
    if input_to_network.grad is not None:
        input_to_network.grad.data.zero_()

    # Derivative of the x component: one backward pass per sample
    for single_batch in output_from_network:
        output_x, _, _ = single_batch
        output_x.backward(retain_graph=True)
    dxdt = input_to_network.grad.clone()
    input_to_network.grad.data.zero_()

    # Derivative of the y component
    for single_batch in output_from_network:
        _, output_y, _ = single_batch
        output_y.backward(retain_graph=True)
    dydt = input_to_network.grad.clone()
    input_to_network.grad.data.zero_()

    # Derivative of the z component
    for single_batch in output_from_network:
        _, _, output_z = single_batch
        output_z.backward(retain_graph=True)
    dzdt = input_to_network.grad.clone()

    return dxdt, dydt, dzdt, output_from_network
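For comparison, here is a minimal sketch of a more compact variant that should give the same derivatives as the function above, collected into one [batch_size, 3] tensor. It relies on an assumption: each output row depends only on its own input row (for example, no batch normalization in training mode). Under that assumption, summing one output column over the batch and differentiating once yields the per-sample derivative, so three calls to torch.autograd.grad replace the per-sample loops. The names d_output_d_input and toy_net are hypothetical and used only for illustration.

```python
import torch


def d_output_d_input(network, t):
    """Return (derivatives, output), both of shape [batch_size, 3].

    Assumption: each output row depends only on its own input row,
    so the batch sum of a column has per-sample gradients.
    """
    # Work on a fresh leaf tensor so gradients can be taken w.r.t. it.
    t = t.detach().requires_grad_(True)
    out = network(t)  # [batch_size, 3]
    columns = []
    for k in range(out.shape[1]):
        # d(sum_b out[b, k]) / d t[b] equals d out[b, k] / d t[b]
        # under the assumption stated above.
        (grad_k,) = torch.autograd.grad(out[:, k].sum(), t, retain_graph=True)
        columns.append(grad_k)  # [batch_size, 1]
    return torch.cat(columns, dim=1), out


# Hypothetical toy network, used only to illustrate the shapes.
toy_net = torch.nn.Sequential(
    torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 3)
)
t = torch.rand(5, 1)
derivs, out = d_output_d_input(toy_net, t)
```

Unlike backward(), torch.autograd.grad does not accumulate into .grad, so nothing has to be zeroed between the three calls.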
Is this the correct way to do it, and is there a more efficient way?
Thanks!