Gradcheck fails for custom activation function

I am implementing this activation function:

import torch
from torch import Tensor
from torch.autograd import Function

class EOActivation(Function):
    '''
    Electro-optic activations as described in - 
        {Williamson, Ian AD, et al. "Reprogrammable electro-optic nonlinear 
        activation functions for optical neural networks." IEEE Journal of 
        Selected Topics in Quantum Electronics 26.1 (2019): 1-12.}
    '''
    @staticmethod
    def forward(Z: Tensor,
                alpha: Tensor,
                g: Tensor,
                phi_b: Tensor) -> Tensor:
        '''
        Forward pass of the EO activation function
        Z: tensor, Input tensor
        alpha: tensor, parameter 'alpha'
        g: tensor, parameter 'g'
        phi_b: tensor, parameter 'phi_b'
        '''
        return 1j * torch.sqrt(1 - alpha) * torch.exp(
            -1j*0.5*g*torch.conj(Z)*Z - 1j*0.5*phi_b) * torch.cos(
                0.5*g*torch.conj(Z)*Z + 0.5*phi_b) * Z
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        '''
        ctx: Context object
        inputs: Inputs are the inputs to forward()
        output: Output tensor of forward()
        '''
        # Save parameters and output of forward for backward pass
        input, alpha, g, phi_b = inputs
        ctx.save_for_backward(input, alpha, g, phi_b)
    
    @staticmethod
    def backward(ctx, grad_Z: Tensor) -> Tensor:
        '''
        ctx: context object
        grad_Z: backpropagated gradient signal from (l+1)th layer
        '''
        # get the parameters and output field computed during forward-pass
        Z, alpha, g, phi_b = ctx.saved_tensors
        zR, zI = Z.real, Z.imag
        # df_dRe - Gradient w.r.t. real part of the input
        df_dRe = torch.sqrt(1 - alpha) * torch.exp((-0.5*1j)*g*(zR - 1j*zI)*(zR + 1j*zI) - (
            0.5*1j)*phi_b) * (zR*g*(zI - 1j*zR) * torch.sin(0.5*(zR**2)*g + 0.5*(
                zI**2)*g + 0.5*phi_b) + ((zR**2)*g + 1j*zR*zI*g + 1j) * torch.cos(
                    0.5*(zR**2)*g + 0.5*(zI**2)*g + 0.5*phi_b))
        #df_dIm - Gradient w.r.t. imaginary part of the input
        df_dIm = torch.sqrt(1 - alpha) * torch.exp((-0.5*1j)*g*(zR - 1j*zI)*(zR + 1j*zI) - (
            0.5*1j) * phi_b) * (zI*g*(zI - 1j*zR) * torch.sin(0.5*(zR**2)*g + 0.5*(
                zI**2)*g + 0.5*phi_b) + (zR*zI*g + 1j*(zI**2)*g - 1) * torch.cos(
                    0.5*(zR**2)*g + 0.5*(zI**2)*g + 0.5*phi_b))
        # Return the gradient and 'None' for parameters in forward()
        return (grad_Z*df_dRe).real - 1j*(grad_Z*df_dIm).real, None, None, None

where alpha, g and phi_b are constant parameters initialized as tensors with “requires_grad=False”. The function is applied in a layer class whose forward() looks like this:

def forward(self, Z: Tensor) -> Tensor:
        '''
        Z: Input tensor from (l-1)th layer
        Z_out: Output tensor after forward propagation (activation)
        '''
        Z = EOActivation.apply(Z, self.alpha, self.g, self.phi_b)       
        if self.photodetect:
            Z = Z.square()
        if self.bias is not None:
            Z = Z + self.bias.unsqueeze(0)

        return Z

Let’s say “photodetect” is always False and “bias” is always None, so neither branch is taken.
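For completeness, the check I run looks roughly like this (a minimal sketch: the shapes and parameter values are made up here, and double-precision complex inputs are used because gradcheck needs them):

import torch
from torch.autograd import gradcheck

# Made-up shapes/values for illustration; only Z requires a gradient.
Z = torch.randn(4, dtype=torch.complex128, requires_grad=True)
alpha = torch.tensor(0.1, dtype=torch.float64)
g = torch.tensor(1.0, dtype=torch.float64)
phi_b = torch.tensor(0.5, dtype=torch.float64)

# Compares the analytical backward() against numerically estimated gradients.
print(gradcheck(EOActivation.apply, (Z, alpha, g, phi_b), eps=1e-6, atol=1e-4))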

My problem is that the activation does not work as expected: when used as part of a neural network, it fails to learn. I am comparing the equations with another implementation and the calculations seem to be correct. My gradcheck output looks like this:


Looks like there is a sign mismatch. Can anyone please tell me if I am missing any properties in my implementation? Or if I need to take any additional care for handling the gradients?

This might be a useful read if you’re dealing with derivatives involving complex numbers: Autograd mechanics — PyTorch 2.5 documentation

Thank you very much for the resource. It is valuable. I understand now that the computation of derivatives of functions of complex numbers is different from how it is done in the reference package I was using. For now, from my initial understanding, I made a quick fix to my derivatives which works, but I do not yet know why. I need to do the derivations myself following the guide you shared. I will mark your comment as the solution so that the topic can be closed. Thank you again.

After going through PyTorch’s gradient conventions for complex numbers, I see that the gradient this backward() should return is not:

(grad_Z*df_dRe).real - 1j*(grad_Z*df_dIm).real

but,

torch.conj(grad_Z) * 0.5 * (df_dRe + 1j * df_dIm) + grad_Z * 0.5 * torch.conj(df_dRe - 1j * df_dIm)
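This is just the Wirtinger chain rule from the autograd mechanics guide, written in terms of df_dRe and df_dIm (this is my reading of the convention, so treat the notation as a sketch). With s = f(z), x = Re(z), y = Im(z), and grad_Z = ∂L/∂s*:

$$
\frac{\partial L}{\partial z^*}
= \overline{\frac{\partial L}{\partial s^*}}\,\frac{\partial s}{\partial z^*}
+ \frac{\partial L}{\partial s^*}\,\overline{\frac{\partial s}{\partial z}},
\qquad
\frac{\partial s}{\partial z} = \frac{1}{2}\left(\frac{\partial s}{\partial x} - i\,\frac{\partial s}{\partial y}\right),
\qquad
\frac{\partial s}{\partial z^*} = \frac{1}{2}\left(\frac{\partial s}{\partial x} + i\,\frac{\partial s}{\partial y}\right),
$$

which is exactly the expression above once the Wirtinger derivatives are substituted.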

However, since the function used in the forward computation is differentiable and can be written using only PyTorch ops, a custom backward implementation is not really necessary in this case.
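For example, something along these lines (a minimal sketch; eo_activation is just a name I picked here, not the code from my project) lets autograd derive the complex gradient on its own:

import torch
from torch import Tensor

def eo_activation(Z: Tensor, alpha: Tensor, g: Tensor, phi_b: Tensor) -> Tensor:
    # Same expression as EOActivation.forward(); autograd tracks the complex ops,
    # so no hand-written backward is needed.
    phase = 0.5 * g * torch.conj(Z) * Z + 0.5 * phi_b
    return 1j * torch.sqrt(1 - alpha) * torch.exp(-1j * phase) * torch.cos(phase) * Z

# In the layer's forward(): Z = eo_activation(Z, self.alpha, self.g, self.phi_b)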