Persisting state across subsequent backward calls


I want to calculate a running mean of the grad_output values passed to backward(). Is there a way to store it so that it’s still there for the subsequent backward calls on each batch? I looked at but it’s not clear where running_mean and running_var are updated. If I update ctx in backward(), will it still be there for the next forward() call?
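
For context, a minimal sketch of one way this might be done. A new ctx object is created for every forward() call, so anything written to ctx in backward() will not be visible in the next forward(). The sketch instead keeps the state in a buffer owned by a module and updates it in place from backward(). The names `GradStats`, `GradMeanTracker`, and `momentum` are made up for illustration, and the exponential-moving-average update is an assumption about what "running mean" should mean here.

```python
import torch


class GradStats(torch.autograd.Function):
    """Identity in forward; updates a running mean of grad_output in backward."""

    @staticmethod
    def forward(ctx, x, running_mean, momentum):
        # ctx only lives for this one forward/backward pair, so the
        # long-lived state is the caller-owned tensor `running_mean`.
        ctx.running_mean = running_mean
        ctx.momentum = momentum
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Update in place so the caller's tensor sees the change.
        with torch.no_grad():
            batch_mean = grad_output.mean()
            ctx.running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * batch_mean)
        # One return value per forward input; only x receives a gradient.
        return grad_output, None, None


class GradMeanTracker(torch.nn.Module):
    """Hypothetical wrapper that owns the persistent state as a buffer."""

    def __init__(self, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        # Buffers persist across batches and are included in state_dict().
        self.register_buffer("running_mean", torch.zeros(()))

    def forward(self, x):
        return GradStats.apply(x, self.running_mean, self.momentum)


if __name__ == "__main__":
    tracker = GradMeanTracker(momentum=0.5)
    x = torch.ones(4, requires_grad=True)
    for scale in (2.0, 4.0):
        (tracker(x) * scale).sum().backward()
        print(tracker.running_mean.item())
```

Because the buffer lives on the module rather than on ctx, it survives across batches, moves with `.to(device)`, and is saved along with the model.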