I have a couple of custom torch.autograd.Function subclasses that perform fairly large computations in their forward and backward methods. For visualization, it would be very useful to plot some of the intermediate results computed in those methods. So far, I have simply put the plotting code directly into forward and backward, but that is inflexible and seems suboptimal. What's the best practice here?
This is how I have been doing the plotting so far. It is awkward because whenever I want to change anything about a plot, I have to edit the code inside the Function class itself:
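    import matplotlib.pyplot as plt
    import torch

    # comp1, comp2 and comp3 stand for my actual (larger) computations.
    class some_function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input1, input2):
            output1 = comp1(input1, input2)
            # Save both tensors in one call; a second call would overwrite the first.
            ctx.save_for_backward(input1, input2)
            return output1

        @staticmethod
        def backward(ctx, grad_output1):
            # saved_tensors is an attribute, not a method.
            input1, input2 = ctx.saved_tensors
            intermediate_result = comp2(grad_output1, input1)
            # Plotting is hard-coded inside backward.
            plt.plot(input1.detach().cpu(), intermediate_result.detach().cpu())
            grad_input2 = comp3(intermediate_result, input2)
            return None, grad_input2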
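One possible way to decouple the plotting from the Function, sketched here under assumptions: pass an optional plain Python callable to apply alongside the tensors, and have forward and backward hand it detached copies of the intermediates. forward accepts non-tensor arguments, and backward must return one value per forward input, so it returns None for the callable. The names MulWithProbe and probe are hypothetical, and the elementwise product is only a stand-in for comp1/comp2/comp3, which are not shown in the question.

```python
import torch


class MulWithProbe(torch.autograd.Function):
    """Elementwise product that can report intermediate tensors to a callback.

    `probe` is a hypothetical design choice, not a PyTorch API: a plain
    callable taking (stage_name, dict_of_tensors), or None to disable it.
    """

    @staticmethod
    def forward(ctx, input1, input2, probe):
        output = input1 * input2  # stand-in for comp1
        ctx.save_for_backward(input1, input2)
        ctx.probe = probe  # non-tensor objects can be stored directly on ctx
        if probe is not None:
            # Detached tensors share memory with the originals; the callback
            # should clone them if it keeps them around.
            probe("forward", {"output": output.detach()})
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        intermediate = grad_output * input1  # stand-in for comp2
        if ctx.probe is not None:
            ctx.probe("backward", {
                "input1": input1.detach(),
                "intermediate": intermediate.detach(),
            })
        grad_input1 = grad_output * input2  # d(input1 * input2) / d(input1)
        grad_input2 = intermediate          # d(input1 * input2) / d(input2)
        # One return value per forward input; the probe gets None.
        return grad_input1, grad_input2, None
```

To plot, pass a callback that calls plt.plot(tensors["input1"].cpu(), tensors["intermediate"].cpu()) when the stage is "backward"; pass None to turn probing off. Changing the plots then only touches the callback, not the Function class.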