Accessing and plotting data from a custom autograd.Function - best practice


I have a couple of custom torch.autograd.Functions that perform fairly large computations in their .forward and .backward methods. For visualization, it would be very useful to plot some of the intermediate results computed there. So far, I have simply added the plotting code directly into .forward and .backward, but that is inflexible and seems suboptimal. What's the best practice here?

This is how I have done the plotting so far. Changing the plots is cumbersome, because whenever I want to modify anything in the plotting, I have to go into the class and edit the code there:

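import torch
import matplotlib.pyplot as plt

class some_function(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input1, input2):
        ctx.save_for_backward(input1, input2)  # needed so backward can access the inputs
        output1 = comp1(input1, input2)  # comp1..comp3 are my own computations
        return output1

    @staticmethod
    def backward(ctx, grad_output1):
        input1, input2 = ctx.saved_tensors  # an attribute, not a method
        intermediate_result = comp2(grad_output1, input1)
        plt.plot(input1.detach().cpu(), intermediate_result.detach().cpu())  # plotting hard-coded in the class
        grad_input2 = comp3(intermediate_result, input2)
        return None, grad_input2  # no gradient needed for input1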

Your approach looks correct, since you need to access the intermediates inside the backward call. If you only wanted to plot the gradients of the input data or of the weights, you could use backward hooks instead (e.g. `Tensor.register_hook` or `nn.Module.register_full_backward_hook`). However, these hooks would not give you access to the intermediates computed inside backward.
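A minimal sketch of one way to keep the plotting code out of the class while still reaching the intermediates from backward: pass a callback into the Function as an extra, non-tensor argument and let backward hand its intermediates to it. A `Tensor.register_hook` on `input2` is shown alongside for comparison; it sees the gradient but not the intermediate. `SomeFunction`, `IntermediateRecorder`, and the `sin`-based math (a stand-in for comp1 to comp3) are hypothetical and only illustrate the pattern.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import torch


class SomeFunction(torch.autograd.Function):
    """Hypothetical stand-in for some_function: out = sin(input1 * input2)."""

    @staticmethod
    def forward(ctx, input1, input2, callback):
        ctx.save_for_backward(input1, input2)
        ctx.callback = callback  # non-tensor objects can be stored directly on ctx
        return torch.sin(input1 * input2)

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        # Stand-in for comp2: gradient w.r.t. u = input1 * input2
        intermediate = grad_output * torch.cos(input1 * input2)
        if ctx.callback is not None:
            # Hand the intermediate to outside code instead of plotting here
            ctx.callback("backward_intermediate", input1.detach(), intermediate.detach())
        # Stand-in for comp3: chain rule through u, where du/d(input2) = input1
        grad_input2 = intermediate * input1
        # One gradient per forward input: input1, input2, callback
        return None, grad_input2, None


class IntermediateRecorder:
    """Hypothetical callback that stores intermediates for later plotting."""

    def __init__(self):
        self.records = []

    def __call__(self, name, x, y):
        self.records.append((name, x.cpu(), y.cpu()))


# Usage: the Function only reports intermediates; all plotting happens out here
input1 = torch.linspace(0.0, 3.0, 50)
input2 = torch.ones(50, requires_grad=True)
recorder = IntermediateRecorder()

# A tensor hook sees the gradient w.r.t. input2, but not the intermediate
grad_log = []
input2.register_hook(lambda grad: grad_log.append(grad.detach().clone()))

output = SomeFunction.apply(input1, input2, recorder)
output.sum().backward()

fig, ax = plt.subplots()
for name, x, y in recorder.records:
    ax.plot(x.numpy(), y.numpy(), label=name)
ax.plot(input1.numpy(), grad_log[0].numpy(), label="grad w.r.t. input2 (hook)")
ax.legend()
plt.close(fig)  # or fig.savefig(...) / plt.show()
```

Because the Function only calls whatever it is given, changing the plots means changing the callback or the code that reads `recorder.records`, not the class itself. Passing `None` as the callback disables the reporting. Note that backward must return one value per forward input, so it returns `None` for the callback argument.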