I’m trying to port the following TensorFlow script to PyTorch. Values are passed through multiple networks, and tape.gradient lets me compute the gradient of each loss with respect to one specific network's variables only:
# tensorflow
with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
    # persistent=True allows tape.gradient to be called more than once
    # watch_accessed_variables=False records gradients only for the
    # variables explicitly registered via tape.watch()
    tape.watch(first_network.variables)
    tape.watch(second_network.variables)
    y = first_network(x)
    z = second_network(y)
    # random_stuff_1 / random_stuff_2 carry no gradient history
    first_loss = y + random_stuff_1
    second_loss = -z + random_stuff_2

first_loss_grads = tape.gradient(first_loss, first_network.variables)
second_loss_grads = tape.gradient(second_loss, second_network.variables)
first_network_optimizer.apply_gradients(zip(first_loss_grads, first_network.variables))
second_network_optimizer.apply_gradients(zip(second_loss_grads, second_network.variables))
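For reference, the closest line-by-line analogue of tape.gradient that I've found is torch.autograd.grad, which computes gradients only for the parameters you pass in and leaves every other .grad buffer untouched. A minimal sketch, assuming first_network / second_network are nn.Module instances, the optimizers come from torch.optim, and the losses are reduced to scalars with .sum() (tape.gradient implicitly sums non-scalar targets, while PyTorch needs a scalar):

import torch

# analogue of the tape block: run the forward passes and build the graph
y = first_network(x)
z = second_network(y)
first_loss = (y + random_stuff_1).sum()
second_loss = (-z + random_stuff_2).sum()

first_params = list(first_network.parameters())
second_params = list(second_network.parameters())

# like tape.gradient(loss, network.variables): gradients are returned
# only for the listed parameters; retain_graph=True plays the role of
# persistent=True so the shared graph survives the first call
first_grads = torch.autograd.grad(first_loss, first_params, retain_graph=True)
second_grads = torch.autograd.grad(second_loss, second_params)

# analogue of apply_gradients: hand the gradients to the optimizers
for p, g in zip(first_params, first_grads):
    p.grad = g
for p, g in zip(second_params, second_grads):
    p.grad = g
first_network_optimizer.step()
second_network_optimizer.step()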
My PyTorch implementation, using .backward() instead:
# pytorch
y = first_network(x)
z = second_network(y)
# random_stuff_1 / random_stuff_2 carry no gradient history
first_loss = y + random_stuff_1
second_loss = -z + random_stuff_2
first_network_optimizer.zero_grad()
second_network_optimizer.zero_grad()
# second_loss.backward() would also accumulate gradients into first_network
# through y, which tape.gradient(second_loss, second_network.variables)
# avoids, so I update the second network first ...
second_loss.backward(retain_graph=True)
second_network_optimizer.step()
# ... then clear the leaked gradients before the first network's update
first_network_optimizer.zero_grad()
first_loss.backward()
first_network_optimizer.step()
Let me know if this implementation looks right.
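One more difference I think matters: if y and z are not scalars, loss.backward() raises "grad can be implicitly created only for scalar outputs", whereas tape.gradient implicitly sums non-scalar targets. My assumption is that reducing the losses reproduces the TF behavior:

# assuming a sum reduction matches what tape.gradient does for
# non-scalar targets
first_loss = (y + random_stuff_1).sum()
second_loss = (-z + random_stuff_2).sum()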