Adversarial network feature decorrelation: passing loss gradient

Hi all! I’m trying to construct an adversarial neural network (ANN) to decorrelate a classifier’s output from a given feature. I think the best way to show it is with a little sketch (I hope it renders):


(the constant can be tuned to balance classification against feature decorrelation)
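Assuming the sketch matches the `loss` line in the code below, the two networks optimise roughly the following. The notation is my own: $f$ is the classifier (`mlp`) with parameters $\theta_f$, $r$ is the adversary (`ANN`) with parameters $\theta_r$, $z$ is the feature the adversary tries to predict (`regtargets`), and $\lambda$ is the constant (`lda`):

$$
\min_{\theta_f}\;\Big[\mathcal{L}_{\text{clf}}\big(f(x),\,y\big) \;-\; \lambda\,\mathcal{L}_{\text{ANN}}\big(r(f(x)),\,z\big)\Big],
\qquad
\min_{\theta_r}\;\mathcal{L}_{\text{ANN}}\big(r(f(x)),\,z\big)
$$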

My problem is how to pass the ANN loss back to the classifier so that the gradient can be computed with respect to the classifier outputs, without updating the ANN itself. What I have now (clearly still wrong) looks something like this:

        #run number of batches
        for i, data in enumerate(trainloader, 0):
            #regtargets are the x feature values that the regressor tries to find
            inputs, targets, regtargets = data
            inputs, targets, regtargets = inputs.float(), targets.float(), regtargets.float()
            targets = targets.reshape((targets.shape[0], 2))
            #reset optimizers
            optimizer_reg.zero_grad()
            optimizer.zero_grad()

            #get classifier predictions
            outputs = mlp(inputs)
            
            #outputs are fed to ANN (detached, otherwise I get an error: args don't support autodiff, but one requires grad)
            ANN_in = outputs.detach()
            ANN_in.requires_grad = False
            #optimise ANN on the new outputs
            ANN = TrainANN_hot(ANN, ANN_in, regtargets)
            ANNoutputs = ANN(ANN_in)
            loss_ANN = loss_function2(ANNoutputs, regtargets)
            loss_ANN.backward()
            optimizer_reg.step()
            optimizer_reg.zero_grad()
            #get new loss after optimization
            ANN_in.requires_grad = True
            ANNoutputs = ANN(ANN_in)
            loss_ANN = loss_function2(ANNoutputs, regtargets)
            
            #loss is calculated and complete model is retrained
            loss = loss_function(outputs, targets) - lda* loss_ANN
            loss.backward()
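One likely reason the classifier never sees the adversary's gradient in the loop above: `ANN_in` is built with `outputs.detach()`, and setting `ANN_in.requires_grad = True` afterwards only turns it into a new leaf tensor. `loss.backward()` therefore deposits the adversarial gradient in `ANN_in.grad` and stops there, so `mlp` only receives the gradient of `loss_function`. Below is a minimal sketch of the alternating update without gradient reversal. It assumes small hypothetical stand-ins for `mlp`, `ANN`, `loss_function` (BCE with logits) and `loss_function2` (MSE), and uses a single adversary step per batch in place of `TrainANN_hot`, whose contents aren't shown:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for the models, losses and optimizers in the post.
mlp = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))  # classifier
ANN = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))  # adversary
loss_function = nn.BCEWithLogitsLoss()  # assumed classifier loss
loss_function2 = nn.MSELoss()           # assumed adversary loss
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
optimizer_reg = torch.optim.Adam(ANN.parameters(), lr=1e-3)
lda = 1.0


def train_step(inputs, targets, regtargets):
    outputs = mlp(inputs)

    # 1) Adversary step: detached outputs, so no gradient reaches mlp here.
    optimizer_reg.zero_grad()
    loss_ANN = loss_function2(ANN(outputs.detach()), regtargets)
    loss_ANN.backward()
    optimizer_reg.step()

    # 2) Classifier step: no detach, so loss_ANN backpropagates through ANN into mlp.
    optimizer.zero_grad()
    loss_ANN = loss_function2(ANN(outputs), regtargets)
    loss = loss_function(outputs, targets) - lda * loss_ANN
    loss.backward()
    # Only mlp is stepped; the grads this left on ANN are cleared by
    # optimizer_reg.zero_grad() on the next batch, so the ANN is not modified here.
    optimizer.step()
    return loss.item(), loss_ANN.item()
```

The key difference from the original loop is that the second ANN pass uses `outputs` itself rather than a detached copy, and only the classifier's optimizer is stepped afterwards.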

I’m new to PyTorch, so I was hoping for some insight into how best to implement gradient reversal.

I eventually found a much more elegant solution here: [Solved] Reverse gradients in backward pass
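The core of that approach is a gradient reversal layer: identity in the forward pass, gradient multiplied by $-\lambda$ in the backward pass. A minimal sketch as a custom `torch.autograd.Function` follows. It is not necessarily identical to the code in that thread; `GradReverse` and `grad_reverse` are hypothetical names, and the models and losses are the same assumed stand-ins as before (BCE with logits for the classifier, MSE for the ANN):

```python
import torch
from torch import nn


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by -lda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lda):
        ctx.lda = lda
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip and scale the gradient; lda is a plain float, so it gets no gradient (None).
        return grad_output.neg() * ctx.lda, None


def grad_reverse(x, lda=1.0):
    return GradReverse.apply(x, lda)


torch.manual_seed(0)
# Hypothetical stand-ins for the post's classifier, adversary, losses and optimizers.
mlp = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
ANN = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
loss_function = nn.BCEWithLogitsLoss()
loss_function2 = nn.MSELoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
optimizer_reg = torch.optim.Adam(ANN.parameters(), lr=1e-3)
lda = 1.0


def train_step(inputs, targets, regtargets):
    optimizer.zero_grad()
    optimizer_reg.zero_grad()
    outputs = mlp(inputs)
    # The ANN sees the classifier outputs unchanged, but the gradient it sends back is reversed:
    # ANN parameters descend loss_ANN, mlp parameters descend loss_clf - lda * loss_ANN.
    loss_ANN = loss_function2(ANN(grad_reverse(outputs, lda)), regtargets)
    loss = loss_function(outputs, targets) + loss_ANN
    loss.backward()
    optimizer.step()
    optimizer_reg.step()
    return loss.item(), loss_ANN.item()
```

With this, a single `backward()` gives the ANN the ordinary gradient of its own loss and gives the classifier the gradient of `loss_function - lda * loss_ANN`, so there is no need for the detach/re-attach step. The trade-off is that both networks are updated on every batch rather than alternately.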