How to perform different forward and backward passes in PyTorch?

Hello, I am working on a problem where I’d like to perform the forward propagation in one specific way and the backpropagation in another. Since PyTorch builds the computation graph when the ‘forward’ function is run, I assume it derives the ‘backward’ pass from the forward function automatically. I’m not sure whether I can change this and define the backward pass manually. Is there any way to do this?

Yes, you can define a custom backward pass by subclassing `torch.autograd.Function` and implementing both its `forward` and `backward` static methods, as shown in this example:
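Below is a minimal sketch of a custom autograd function whose backward pass deliberately differs from the true derivative of its forward pass. The class name `StraightThroughSign` and the choice of a straight-through estimator are illustrative assumptions, not taken from the original answer. The forward pass computes `sign(x)`, whose true gradient is zero almost everywhere, while the hand-written backward pass lets the gradient through wherever `|x| <= 1`. `ctx.save_for_backward` stores the input so `backward` can use it, and the function is invoked through `.apply`.

```python
import torch


class StraightThroughSign(torch.autograd.Function):
    """Illustrative example: forward is sign(x); backward is a custom
    straight-through estimator instead of the true (zero) derivative."""

    @staticmethod
    def forward(ctx, x):
        # Save the input so backward can decide where to pass gradients.
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Custom rule: pass the incoming gradient through where |x| <= 1,
        # and block it (zero) elsewhere. Return one gradient per forward input.
        mask = (x.abs() <= 1).to(grad_output.dtype)
        return grad_output * mask


# Usage: call through .apply, not by instantiating the class.
x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], requires_grad=True)
y = StraightThroughSign.apply(x)
y.sum().backward()

print(y.detach())  # tensor([-1., -1.,  0.,  1.,  1.])
print(x.grad)      # tensor([0., 1., 1., 1., 0.])
```

If you want this behavior inside your own `nn.Module`, call `StraightThroughSign.apply(x)` from the module's `forward` method; autograd will then use your `backward` whenever gradients flow through that operation.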