MinMax Adversarial Loss

I have a multi-task learning model with two multi-class classification tasks. One part of the model creates a shared feature representation that is fed into two subnets in parallel. The loss function for each subnet is currently cross entropy.
I want to minimize the cross entropy for one task and maximize it for the other, so the model doesn't/can't learn anything about that second task; the resulting accuracy for that task should then be 1/num_of_labels, i.e. no better than a random guess.

In my network, I have two classifiers:

class Adversarial(nn.Module):
    def __init__(self):
        super(Adversarial, self).__init__()       
        self.cnn = CNN() 
        self.emotion_classfier = EmotionClassfier()
        self.speaker_invariant = SpeakerInvariant()

    def forward(self, input):
        input         = self.cnn(input)
        input_emotion = self.emotion_classfier(input)
        input_speaker = self.speaker_invariant(input) 
       """First confusion is here, how shall handle it here and then how to implemenet this loss function  """ 
        return input_stutter, input_speaker

class EmotionClassfier(nn.Module):
    def __init__(self):
        super(EmotionClassfier, self).__init__() 
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 5)     #5 here is number of emotion classes 
        self.relu = nn.ReLU()

    def forward(self, input):
        input = input.view(input.size(0), -1)
        input = self.fc1(input)
        input = self.relu(input)
        input = self.fc2(input)
        return input

class SpeakerInvariant(nn.Module):
    def __init__(self):
        super(SpeakerInvariant, self).__init__()
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 11)   #11 here is number of speakers/labels
        self.relu = nn.ReLU()

    def forward(self, input):
        input = input.view(input.size(0), -1)
        input = self.fc1(input)
        input = self.relu(input)
        input = self.fc2(input)       
        return input

and the loss function I want to implement is as follows:

criterion = nn.CrossEntropyLoss()   # for EmotionClassfier as well as SpeakerInvariant
loss_emotion = criterion(output_emotion, emotion_label)   # want to minimize this loss
loss_speaker = criterion(output_speaker, speaker_label)
# Here I want to maximize loss_speaker (suppose there are five speaker labels and
# I want to learn speaker-independent features), i.e. I want to apply Gradient
# Reversal for the SpeakerInvariant network.
loss = loss_emotion + lamda * loss_speaker   # "lambda" is a reserved word, hence lamda

Could you please help out here?

From the paper:

https://www.microsoft.com/en-us/research/uploads/prod/2018/04/ICASSP2018_Speaker_Invariant_Training.pdf

Hi Shakeel!

I don’t know that much about adversarial networks, and I haven’t
looked at the paper to which you linked, but let me outline some
details that might be relevant to what you are trying to do:

Your network comes in three pieces: the front end that produces your
“shared feature representation,” let’s call it FeatureNet; the one half
of your back end that predicts emotions, call it EmotionNet; and the
other half that predicts the speakers, SpeakerNet.

To be clear, you don’t want to train SpeakerNet to do a poor job of
predicting the speakers. (That would be trivial.)

The core goal of your adversarial network is to train FeatureNet
to generate features that can be used to successfully predict
emotions, but that are not useful for predicting the speakers.

The key idea is that we actively train SpeakerNet to do the best
job it can predicting the speakers using the output of FeatureNet,
but train FeatureNet so that SpeakerNet does poorly (while
EmotionNet does well).

We introduce two loss functions, emotionLoss and speakerLoss, and
the combined loss, loss = emotionLoss + lambda * speakerLoss.
(In your case the two losses are both CrossEntropyLoss.)

We backpropagate / optimize loss normally through both EmotionNet
and SpeakerNet, so that we train both to make their predictions as
successfully as they can, given the features produced by FeatureNet.
(Notice that emotionLoss doesn’t depend on SpeakerNet, so
backpropagating the gradient of emotionLoss through SpeakerNet
doesn’t affect how SpeakerNet’s weights are updated. And vice
versa.)
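
A quick way to convince yourself of this, with hypothetical names
(feature_net, emotion_net, speaker_net, criterion) standing in for
the pieces above:

features = feature_net(x)
emotion_loss = criterion(emotion_net(features), emotion_labels)
emotion_loss.backward()
print(speaker_net.fc1.weight.grad)   # None: SpeakerNet is not part of
                                     # emotion_loss's computation graph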

But we keep track of the emotionLoss and speakerLoss gradients
separately so that before we backpropagate through FeatureNet
we can flip the sign of the speakerLoss gradient. (This would be
your “Gradient Reversal.”)
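
Written out with thetaF, thetaE, and thetaS for the parameters of the
three pieces, flipping that one sign means a single backward pass is,
in effect, optimizing

   minimize over thetaF, thetaE:  emotionLoss - lambda * speakerLoss
   minimize over thetaS:          speakerLoss

which is exactly the min-max / saddle-point objective you want.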

You are now training FeatureNet to produce features that let
EmotionNet be trained to do well predicting emotions, but also
prevent SpeakerNet from doing well predicting the speakers,
even though the SpeakerNet portion of your whole network is
being trained to do as well as it can, given the features it gets
from FeatureNet.

(I’m not aware of anything built into pytorch that will do “Gradient
Reversal” for you, and I’m not really sure how best to implement
something like this. In any event, at the point where the gradients
are being backpropagated from EmotionNet and SpeakerNet back
through FeatureNet, you somehow have to get your hands on the
speakerLoss / SpeakerNet gradient and flip its sign, before passing
it on back through FeatureNet.)

Good luck.

K. Frank

@ptrblck Can you please help here.

I think @KFrank provided a great answer. Do you have specific questions and did you try to follow the suggestions?

@ptrblck
Yes, the specific question is regarding gradient reversal:
how do we flip the gradients here?
Could you please provide an example with SpeakerNet?

@ptrblck
I am trying to apply the gradient reversal layer (GRL) from Speaker Invariant Training and from Unsupervised Domain Adaptation by Backpropagation to the SpeakerNet part, as shown below.
The idea is to learn speaker-invariant emotion features by adversarial training with the GRL, i.e. to unlearn speaker information.

Is my implementation of the GRL correct here, and in the main training function do I have to simply sum the losses of EmotionNet and SpeakerNet and then backpropagate (see the sketch after the code below)?

class EncoderNet(nn.Module):
    def __init__(self):
        super(EncoderNet, self).__init__()
        self.cnn = TDNN()   # assume a shared encoder here

    def forward(self, input):
        return self.cnn(input)

        
class Multitask(nn.Module):
    def __init__(self):
        super(Multitask, self).__init__()       
        self.enc_net           = EncoderNet()
        self.emotion_classfier = EmotionClassfier()
        self.speaker_invariant = SpeakerInvariant()

    def forward(self, x):
        input = self.enc_net(x)
        input_emotion = self.emotion_classfier(input)
        input_speaker = self.speaker_invariant(input) 
       """First confusion is here, how shall handle it here and then how to implemenet this loss function  """ 
        return [input_emotion, input_speaker]

class EmotionClassfier(nn.Module):
    def __init__(self):
        super(EmotionClassfier, self).__init__() 
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 5)     #5 here is number of emotion classes 
        self.relu = nn.ReLU()

    def forward(self, input):
        input = input.view(input.size(0), -1)
        input = self.fc1(input)
        input = self.relu(input)
        input = self.fc2(input)
        return input

class SpeakerInvariant(nn.Module):
    def __init__(self):
        super(SpeakerInvariant, self).__init__()
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 10)   #10 here is number of speakers/labels
        self.relu = nn.ReLU()
   
    def forward(self, input):
        input = ReverseGradients.apply(input, 0.5)   # GRL coefficient (lamda) is 0.5
        input = self.relu(self.fc1(input))
        return self.fc2(input)
"""For GRadient Reversal Layer """
from torch.autograd import Function

class ReverseGradients(Function):
    """Identity in the forward pass; negate (and scale) the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, lamda):
        ctx.lamda = lamda
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by lamda.
        output = grad_output.neg() * ctx.lamda
        # Return None for the lamda argument, which needs no gradient.
        return output, None
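
In the main training function I was planning to just sum the two losses
and call backward() once, along these lines (a sketch; the optimizer
choice, learning rate, and train_loader are placeholders):

import torch
import torch.nn as nn

model = Multitask()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for features, emotion_label, speaker_label in train_loader:
    optimizer.zero_grad()
    out_emotion, out_speaker = model(features)
    loss_emotion = criterion(out_emotion, emotion_label)
    loss_speaker = criterion(out_speaker, speaker_label)
    # The GRL inside SpeakerInvariant already negates the speaker
    # gradient flowing back into the shared encoder, so the two
    # losses can simply be added here.
    loss = loss_emotion + loss_speaker
    loss.backward()
    optimizer.step()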