The forward function is complete and consists solely of tensor operations (matrix multiplication, argmax, concatenation, etc.) that transform the inputs and parameters. It’s not something I could differentiate by hand.

How can I implement the .backward() function in this case? I want to obtain the gradients of all the inputs, and my understanding was that PyTorch could compute the derivatives automatically.

Apologies if I’m missing something very obvious here.

The function currently looks as follows (slightly simplified); the full ‘transformationFn’ lives inside the forward function.
input: N x 4 tensor
params1: 1 x 2 tensor
params2: 1 x 2 tensor
transformationFn takes [input, params1, params2] → returns N x 4 tensor

If your function consists solely of (usefully-differentiable) PyTorch tensor operations, you don’t need to package it as an official torch.autograd.Function nor write an explicit backward() function; autograd will work “automatically” when you call backward() on your final result (typically your loss).
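
As a rough illustration of the point above, here is a minimal sketch using the shapes from the question. The body of `transformation_fn` and the squared-mean loss are hypothetical stand-ins, not the original poster's code; the only assumption that matters is that the parameters are created with `requires_grad=True` and reach the loss through differentiable operations.

```python
import torch

def transformation_fn(inp, params1, params2):
    # Hypothetical stand-in for the real transformationFn:
    # scale the first two columns, shift the last two, then concatenate.
    scaled = inp[:, :2] * params1    # 1 x 2 broadcasts over N x 2
    shifted = inp[:, 2:] + params2   # 1 x 2 broadcasts over N x 2
    return torch.cat([scaled, shifted], dim=1)  # N x 4

inp = torch.randn(5, 4)                          # N x 4 with N = 5
params1 = torch.ones(1, 2, requires_grad=True)   # leaf tensor tracked by autograd
params2 = torch.zeros(1, 2, requires_grad=True)  # leaf tensor tracked by autograd

out = transformation_fn(inp, params1, params2)
loss = out.pow(2).mean()   # placeholder scalar loss
loss.backward()            # no custom backward() needed

print(params1.grad.shape)  # gradient has the same shape as params1
print(params2.grad.shape)  # gradient has the same shape as params2
```

No subclass of torch.autograd.Function appears anywhere here: autograd records the operations as they run and walks that record when backward() is called.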

Thanks very much, KFrank! I had tried this previously but was getting ‘None’ for all but one of my gradients, so I assumed I had not implemented the function properly.
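
One possible cause of ‘None’ gradients, offered only as a guess since the actual code is not shown, is that a parameter reaches the loss solely through a non-differentiable operation such as argmax. The sketch below uses made-up tensors and a made-up computation to show the symptom: argmax returns integer indices with no gradient history, so anything upstream of it is cut off from the graph.

```python
import torch

inp = torch.randn(5, 4)
params1 = torch.ones(1, 2, requires_grad=True)
params2 = torch.ones(1, 2, requires_grad=True)

# params2 influences the result only through argmax, which returns
# integer indices that autograd cannot differentiate through.
idx = torch.argmax(inp[:, 2:] * params2, dim=1)  # shape (5,), dtype int64

# params1 reaches the loss through ordinary multiplication and indexing.
picked = (inp[:, :2] * params1)[torch.arange(5), idx]
loss = picked.sum()
loss.backward()

print(idx.requires_grad)     # False: the graph is cut at argmax
print(params1.grad is None)  # False: params1 received a gradient
print(params2.grad is None)  # True: no differentiable path to the loss
```

Checking `requires_grad` and `grad_fn` on intermediate tensors is a quick way to find where the graph is being cut. Other common causes include a parameter that is never used in computing the loss, or reading `.grad` on a non-leaf tensor.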

I will now look at the training loop to see what the issue is.