Constrain parameters to satisfy an equation in PyTorch


I’m trying to compute probabilities from parametric probability amplitudes in PyTorch for a variant of Dropout. I’ve implemented the necessary complex arithmetic in the code below. However, I need to constrain my parameters so that the two complex amplitudes x and y (complex_a and complex_b in the code) satisfy abs(x)**2 + abs(y)**2 = 1, and I don’t know how to do this in PyTorch.
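Two standard ways to make this constraint hold by construction, rather than encouraging it with a penalty, are to rescale both amplitudes by their joint norm or to write them in terms of an angle. The symbols θ, φ_a and φ_b below are introduced here for illustration and do not appear in the original code. The first form holds for any amplitudes that are not both zero; the second holds for any real θ, φ_a, φ_b:

```latex
\tilde{x} = \frac{x}{\sqrt{|x|^2 + |y|^2}}, \quad
\tilde{y} = \frac{y}{\sqrt{|x|^2 + |y|^2}}
\;\Longrightarrow\;
|\tilde{x}|^2 + |\tilde{y}|^2 = \frac{|x|^2 + |y|^2}{|x|^2 + |y|^2} = 1

x = \cos\theta \, e^{i\phi_a}, \quad y = \sin\theta \, e^{i\phi_b}
\;\Longrightarrow\;
|x|^2 + |y|^2 = \cos^2\theta + \sin^2\theta = 1
```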

My code:

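import torch
import torch.nn as nn

class Complex(nn.Module):
    """Complex number stored as trainable real and imaginary parts."""
    def __init__(self):
        super(Complex, self).__init__()
        self.real_param_a = nn.Parameter(torch.ones(1, 1))
        self.imag_param_a = nn.Parameter(torch.ones(1, 1))

    def abs_square(self):
        # |z|^2 = Re(z)^2 + Im(z)^2
        return (self.real_param_a ** 2) + (self.imag_param_a ** 2)

class QDropout(nn.Module):
    def __init__(self):
        super(QDropout, self).__init__()
        self.complex_a = Complex()
        self.complex_b = Complex()

    def forward(self, x):
        # These two values should satisfy probability_one + probability_two == 1,
        # but nothing here enforces that yet.
        probability_one = self.complex_a.abs_square()
        probability_two = self.complex_b.abs_square()
        # (the rest of the dropout logic is not shown)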
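A minimal sketch of the rescaling approach applied to the Complex class above, assuming the goal is only that the two probabilities sum to 1 (the raw parameters themselves stay unconstrained). NormalizedAmplitudes, AngleProbabilities, the eps value and the toy loss are hypothetical names and choices for illustration. AngleProbabilities omits the phases because they do not change abs(x)**2 or abs(y)**2.

```python
import torch
import torch.nn as nn


class Complex(nn.Module):
    """Complex number stored as trainable real and imaginary parts."""

    def __init__(self):
        super().__init__()
        self.real_param_a = nn.Parameter(torch.ones(1, 1))
        self.imag_param_a = nn.Parameter(torch.ones(1, 1))

    def abs_square(self):
        # |z|^2 = Re(z)^2 + Im(z)^2
        return self.real_param_a ** 2 + self.imag_param_a ** 2


class NormalizedAmplitudes(nn.Module):
    """Hypothetical helper: returns |x|^2 and |y|^2 rescaled to sum to 1."""

    def __init__(self, eps=1e-12):
        super().__init__()
        self.complex_a = Complex()
        self.complex_b = Complex()
        self.eps = eps  # guards against division by zero if both amplitudes vanish

    def forward(self):
        p_a = self.complex_a.abs_square()
        p_b = self.complex_b.abs_square()
        # Dividing each amplitude by sqrt(|x|^2 + |y|^2) divides each squared
        # magnitude by (|x|^2 + |y|^2), so the two results sum to 1.
        total = torch.clamp(p_a + p_b, min=self.eps)
        return p_a / total, p_b / total


class AngleProbabilities(nn.Module):
    """Hypothetical alternative: |x|^2 = cos^2(theta), |y|^2 = sin^2(theta)."""

    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.tensor(0.3))

    def forward(self):
        return torch.cos(self.theta) ** 2, torch.sin(self.theta) ** 2


if __name__ == "__main__":
    torch.manual_seed(0)
    model = NormalizedAmplitudes()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(5):
        p_a, p_b = model()
        # Toy objective (hypothetical): push the first probability towards 0.8.
        loss = ((p_a - 0.8) ** 2).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    p_a, p_b = model()
    print(p_a.item(), p_b.item(), (p_a + p_b).item())
```

Because the rescaling is part of the forward pass, gradients flow through it and the probabilities still sum to 1 after every optimizer step. A penalty term such as (p_a + p_b - 1)**2 added to the loss would only push the sum towards 1 approximately.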