PyTorch JVP slow

Using functional.jvp and then backpropagating through its result is 2 times slower in general, and for larger networks (ResNet-20) it is 4 times slower. Are there any known issues with using JVP? Are there any ways to make it more efficient (maybe using other packages)?
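For reference, here is a minimal sketch of the pattern described above. It assumes a small hypothetical MLP in place of ResNet-20 and a made-up loss (the squared norm of the JVP). It computes the JVP with `create_graph=True` so the result stays differentiable, then backpropagates into the model parameters. It does not reproduce the reported timings.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jvp

torch.manual_seed(0)

# Hypothetical small model standing in for the network from the question.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))

x = torch.randn(5, 4)
v = torch.randn(5, 4)  # tangent direction, same shape as x

def f(inp):
    return model(inp)

# create_graph=True keeps the JVP differentiable so we can backprop through it.
out, jvp_out = jvp(f, (x,), (v,), create_graph=True)

# Hypothetical scalar loss built from the JVP, backpropagated into the parameters.
loss = jvp_out.pow(2).sum()
loss.backward()

# The last layer's bias does not affect J @ v, so its .grad stays None; skip it.
grad_norm = sum(p.grad.norm() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
print(tuple(jvp_out.shape), float(grad_norm))
```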


As mentioned in the note on the function definition, the function uses the “double backward trick” to compute the JVP and is therefore expected to be slower.
We are working on changing that for 1.7.
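To illustrate what the note refers to, here is a minimal sketch of the double backward trick. It is not PyTorch's actual implementation, and the names `jvp_double_backward`, `f` and `A` are hypothetical. A first reverse pass with a dummy cotangent `u` gives `g = J^T u` (kept differentiable with `create_graph=True`). Because `g` is linear in `u`, a second reverse pass of `g` with respect to `u`, seeded with `v`, returns `J v`. The result is checked against `functional.jvp` and an explicitly materialized Jacobian.

```python
import torch
from torch.autograd.functional import jacobian, jvp

def jvp_double_backward(f, x, v):
    """Return (f(x), J(x) @ v) using two reverse-mode passes."""
    x = x.detach().requires_grad_(True)
    y = f(x)
    # Dummy cotangent u: g(u) = J^T u is linear in u, so its value is irrelevant.
    u = torch.zeros_like(y, requires_grad=True)
    # First backward pass, kept differentiable w.r.t. u.
    (g,) = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)
    # Second backward pass: the gradient of <J^T u, v> w.r.t. u is J @ v.
    (jv,) = torch.autograd.grad(g, u, grad_outputs=v)
    return y.detach(), jv

torch.manual_seed(0)
A = torch.randn(3, 4, dtype=torch.float64)

def f(x):
    # Hypothetical test function from R^4 to R^3.
    return A @ torch.tanh(x)

x = torch.randn(4, dtype=torch.float64)
v = torch.randn(4, dtype=torch.float64)

_, jv_trick = jvp_double_backward(f, x, v)
_, jv_builtin = jvp(f, x, v)
jv_explicit = jacobian(f, x) @ v  # reference: build J explicitly, then multiply
print(jv_trick, jv_builtin, jv_explicit)
```

Since `J v` comes out of two reverse passes, one of them over a graph built with `create_graph=True`, it costs roughly two backward passes plus the extra graph bookkeeping. Backpropagating through that result adds more work on top. This gives some intuition for the slowdown reported in the question, but it is not a benchmark.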


Thank you for the reply! That is great to hear. When do you expect 1.7 to be released?

It should be early November IIRC.

Are these changes already available in pytorch-nightly?


I am afraid this got delayed and won’t be in 1.7.
But it should get into master soon after we are done working on the release (mid to end of November).


Thank you for your answer. Is there an issue I can follow to know when this is merged into master?

Sure, you can follow the issue that contains most of the discussion around this.