Multiple model backpropagations in a loop

I’ve seen similar posts, but their problems are different from mine and their solutions don’t seem to apply.

I have 3 models: source_generator, target_generator and a classifier (discriminator). The generators are feature extractors that process images and output an embedding. Each generator is intended to process only images from its own domain, either source or target. The discriminator classifies the embeddings as source or target. The goal of this training is to implement Adversarial Discriminative Domain Adaptation for the target_generator model.

The problem comes with the training loop:

 discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=args.learning_rate)
 generator_optimizer = optim.Adam(target_generator.parameters(), lr=args.learning_rate)

 loss = torch.nn.BCEWithLogitsLoss()

 # Train loop
 print('Training...')
 for epoch in range(args.epochs):
     # Train step
     for i, batch in enumerate(train_dl):
         inputs = batch['image']
         labels = batch['fake']
         
         # Separate the inputs by label
         source_index = labels == 1
         target_index = labels == 0
         source_inputs = inputs[source_index].float().to(device)
         target_inputs = inputs[target_index].float().to(device)

         # Zero grads
         discriminator_optimizer.zero_grad()

         # Process each input with corresponding generator
         source_outputs = source_generator(source_inputs)[0]
         target_outputs = target_generator(target_inputs)[0] 

         # Compute discriminator
         source_labels = torch.Tensor(np.array([np.ones(1) for _ in range(len(source_inputs))])).to(device)
         target_labels = torch.Tensor(np.array([np.zeros(1) for _ in range(len(target_inputs))])).to(device)

         source_loss = loss(discriminator(source_outputs), source_labels)
         target_loss = loss(discriminator(target_outputs), target_labels)
         discriminator_loss = source_loss + target_loss

         # Do backpropagation for discriminator
         discriminator_loss.backward(retain_graph=True)
         discriminator_optimizer.step()

         # Do backpropagation for generator
         target_loss.backward()
         generator_optimizer.step() 

I get RuntimeErrors related to a variable that has been modified:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [16, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

So I understand that the problem is related to the fact that the discriminator and the target_generator are connected, because the output of the latter is used as the input of the former.
Is there another approach to this? Can I somehow backpropagate through only one model at a time?

You are most likely running into this issue: computing gradients with stale forward activations after a parameter update has already been performed.
The link shows a code snippet and describes the issue in more detail. Based on your code, I would assume that this part:

         # Do backpropagation for discriminator
         discriminator_loss.backward(retain_graph=True)
         discriminator_optimizer.step()

         # Do backpropagation for generator
         target_loss.backward()

fails, since target_loss.backward() needs the original parameters of the discriminator to compute the gradients, but these were already updated in place by discriminator_optimizer.step(), which makes the stored intermediate forward activations stale.
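A minimal reproduction of this error might look like the sketch below. It assumes two tiny nn.Linear layers (gen and disc, hypothetical stand-ins for target_generator and discriminator) and plain SGD instead of Adam. The second backward() needs disc.weight as it was saved during the forward pass, but step() has already modified that weight in place.

```python
import torch

torch.manual_seed(0)
gen = torch.nn.Linear(3, 4)   # hypothetical stand-in for target_generator
disc = torch.nn.Linear(4, 1)  # hypothetical stand-in for discriminator
opt_d = torch.optim.SGD(disc.parameters(), lr=0.1)
x = torch.randn(8, 3)

# Problematic order: step the discriminator, then backprop through it again.
loss = disc(gen(x)).pow(2).mean()
loss.backward(retain_graph=True)
opt_d.step()  # updates disc.weight in place, so the saved copy is now stale
try:
    loss.backward()
    error = None
except RuntimeError as exc:
    error = str(exc)  # "...modified by an inplace operation..."

# One way out: run a fresh forward pass after the step, so the new
# graph saves the updated weights.
opt_d.zero_grad()
fresh_loss = disc(gen(x)).pow(2).mean()
fresh_loss.backward()  # no error
```

In other words, every backward() must either run before the step() that modifies the parameters it goes through, or use a graph built after that step.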

Hi David!

There are a couple of different things going on here.

Most importantly, it looks like you are training your target_generator backwards,
see below.

First, @ptrblck’s explanation of the inplace-modification error is almost
certainly correct.

         discriminator_optimizer.step()

performs an inplace modification of discriminator’s parameters.

With:

         target_loss.backward()

you then backward through discriminator again. It is highly likely that at least
some of discriminator’s parameters will be needed to perform that backward
pass, so, because they’ve been modified, you get the error.

But what to do about it?

Here’s the issue about training target_generator “backwards:”

target_generator is supposed to generate a target_outputs that is hard
for discriminator to tell apart from source_outputs.

If discriminator correctly recognizes that target_outputs is “fake,” its
predictions for target_outputs will be close to target_labels, which are
zeros (that is, large negative logits, since BCEWithLogitsLoss applies the
sigmoid internally). If it does so, it will make target_loss small. Using
target_loss to train discriminator makes sense, as a small value for target_loss
“rewards” discriminator for identifying target_outputs as fake.

But we want to penalize target_generator for having produced a
target_outputs that discriminator recognized as fake. (This is the
“adversarial” part.)

By also training target_generator with target_loss we actually reward
target_generator for producing a detectably-fake target_outputs. This
is backwards of what we want.
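The “backwards” direction can be checked numerically. The sketch below assumes a tiny nn.Linear discriminator and uses a random tensor feat (a hypothetical stand-in for target_outputs). Moving the features one step along the negative gradient of target_loss, which is the direction a generator update minimizing target_loss pushes them, lowers every discriminator logit. That means the features look more fake, not less.

```python
import torch

torch.manual_seed(0)
discriminator = torch.nn.Linear(4, 1)
loss = torch.nn.BCEWithLogitsLoss()

# Hypothetical stand-in for target_outputs; requires_grad so we get its gradient.
feat = torch.randn(8, 4, requires_grad=True)
target_labels = torch.zeros(8, 1)  # "fake" label used in target_loss

logits_before = discriminator(feat).detach()
target_loss = loss(discriminator(feat), target_labels)
target_loss.backward()

with torch.no_grad():
    # One gradient-descent step on the features themselves.
    feat_after = feat - 1.0 * feat.grad
    logits_after = discriminator(feat_after)

# Lower logits: the discriminator is MORE confident these features are fake.
print((logits_after < logits_before).all().item())  # True
```

For a linear discriminator with weight w, each logit drops by sigmoid(z) * ||w||² / 8 here, which is always positive, so this is not a coincidence of the random seed.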

As an aside, yes you can perform your “backwards” logic in a way where
you (more or less) backpropagate one model at a time. (But, as explained
above, this is not what you want to do.)

Here’s how:

         # Process each input with corresponding generator
         source_outputs = source_generator(source_inputs)[0]
         target_outputs = target_generator(target_inputs)[0] 

         # Compute discriminator
         source_labels = torch.Tensor(np.array([np.ones(1) for _ in range(len(source_inputs))])).to(device)
         target_labels = torch.Tensor(np.array([np.zeros(1) for _ in range(len(target_inputs))])).to(device)

         # Zero grads
         generator_optimizer.zero_grad()
         discriminator_optimizer.zero_grad()

         target_loss = loss(discriminator(target_outputs), target_labels)

         # Do backpropagation for generator and PART OF discriminator
         target_loss.backward()
         generator_optimizer.step() 

This accumulates the gradients of target_loss into the parameters of both
target_generator and discriminator and performs an optimization step
on target_generator, modifying it inplace. The graph for target_loss is
no longer needed, so it is not retained.
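Whether retain_graph=True is needed depends only on whether a later backward() has to traverse part of a graph that an earlier backward() already freed. The sketch below uses hypothetical stand-in models gen and disc. With two separate forward passes the graphs are disjoint and no retention is needed; when both losses share one forward pass, the second backward() fails.

```python
import torch

torch.manual_seed(0)
gen = torch.nn.Linear(3, 4)   # hypothetical stand-in for a generator
disc = torch.nn.Linear(4, 1)  # hypothetical stand-in for the discriminator

# Case 1: two SEPARATE forward passes -> two disjoint graphs.
loss_t = disc(gen(torch.randn(5, 3))).sum()
loss_s = disc(gen(torch.randn(5, 3))).sum()
loss_t.backward()  # frees only loss_t's graph
loss_s.backward()  # fine: nothing it needs was freed

# Case 2: ONE forward pass shared by both losses.
shared = gen(torch.randn(10, 3))
loss_t = disc(shared[:5]).sum()
loss_s = disc(shared[5:]).sum()
loss_t.backward()  # also frees the shared gen(...) part of the graph
try:
    loss_s.backward()
    shared_error = None
except RuntimeError as exc:
    shared_error = str(exc)  # "Trying to backward through the graph a second time..."
```

Parameters themselves are leaves and hold no saved activations, so reusing the same module in two separate forward passes does not by itself make the graphs overlap.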

         source_loss = loss(discriminator(source_outputs), source_labels)

         # Do backpropagation for THE REST OF discriminator
         source_loss.backward()
         discriminator_optimizer.step()

The parameters of discriminator still have their target_loss gradients.
This now accumulates the source_loss gradients into discriminator’s
parameters (so the total gradient becomes that of discriminator_loss
which is not explicitly computed in this version) and performs an optimization
step.
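The claim that two backward() calls without zeroing in between give the same gradient as one backward() on the sum can be verified directly. A minimal sketch, assuming a tiny nn.Linear discriminator and random stand-in embeddings:

```python
import torch

torch.manual_seed(0)
discriminator = torch.nn.Linear(4, 1)
loss = torch.nn.BCEWithLogitsLoss()
source_outputs, target_outputs = torch.randn(6, 4), torch.randn(6, 4)
source_labels, target_labels = torch.ones(6, 1), torch.zeros(6, 1)

# Two separate backward calls: gradients accumulate in .grad.
discriminator.zero_grad()
loss(discriminator(target_outputs), target_labels).backward()
loss(discriminator(source_outputs), source_labels).backward()
accumulated = discriminator.weight.grad.clone()

# One backward call on the explicit sum (discriminator_loss).
discriminator.zero_grad()
discriminator_loss = (loss(discriminator(source_outputs), source_labels)
                      + loss(discriminator(target_outputs), target_labels))
discriminator_loss.backward()
combined = discriminator.weight.grad.clone()

print(torch.allclose(accumulated, combined))  # True
```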

source_loss does not depend on target_generator so
source_loss.backward() does not backpropagate anything through
target_generator. Therefore no inplace-modification error occurs,
even though target_generator had been modified inplace.

Even though you don’t want to do this, it’s worth understanding how this
scheme implements your “backwards” logic without an inplace-modification
error.

Now to what you most likely actually want to do:

At the cost of applying discriminator twice to target_outputs we will
compute two versions of target_loss – one that rewards target_generator
for fooling discriminator and a second that rewards discriminator for not
getting fooled.

         # Process each input with corresponding generator
         source_outputs = source_generator(source_inputs)[0]
         target_outputs = target_generator(target_inputs)[0]

         target_labelsA = torch.Tensor(np.array([np.ones(1) for _ in range(len(target_inputs))])).to(device)    # reward target_generator for fooling discriminator
         source_labels = torch.Tensor(np.array([np.ones(1) for _ in range(len(source_inputs))])).to(device)
         target_labelsB = torch.Tensor(np.array([np.zeros(1) for _ in range(len(target_inputs))])).to(device)   # reward discriminator for not getting fooled

         # Zero grads
         generator_optimizer.zero_grad()

         target_lossA = loss(discriminator(target_outputs), target_labelsA)
         target_lossA.backward()
         generator_optimizer.step()

         # clear out target_lossA grads which we don't want
         discriminator_optimizer.zero_grad()

         # detach so that we will not backpropagate through target_generator
         target_outputs = target_outputs.detach()

         source_loss = loss(discriminator(source_outputs), source_labels)
         target_lossB = loss(discriminator(target_outputs), target_labelsB)
         discriminator_loss = source_loss + target_lossB

         # Do backpropagation for discriminator
         discriminator_loss.backward()
         discriminator_optimizer.step()

To recap: This scheme uses two versions of target_loss – one to train
target_generator and a second to train discriminator. Because these
two versions are, in a sense, the opposite of one another, this logic correctly
trains target_generator to fool discriminator, rather than training
target_generator backwards.
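For experimenting outside the real training script, the scheme above can be wrapped into a self-contained step function. This is a sketch under stated assumptions: the three models are hypothetical nn.Linear stand-ins (inputs of size 3, embeddings of size 4), labels are built with torch.ones/torch.zeros instead of the numpy construction, and everything runs on the CPU.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins for the real models.
source_generator = torch.nn.Linear(3, 4)
target_generator = torch.nn.Linear(3, 4)
discriminator = torch.nn.Linear(4, 1)

loss = torch.nn.BCEWithLogitsLoss()
generator_optimizer = torch.optim.Adam(target_generator.parameters(), lr=1e-3)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

def train_step(source_inputs, target_inputs):
    source_outputs = source_generator(source_inputs)
    target_outputs = target_generator(target_inputs)
    n_s, n_t = len(source_inputs), len(target_inputs)

    # 1) Generator update: label target embeddings as "source" (1)
    #    so that fooling the discriminator is rewarded.
    generator_optimizer.zero_grad()
    target_lossA = loss(discriminator(target_outputs), torch.ones(n_t, 1))
    target_lossA.backward()
    generator_optimizer.step()

    # 2) Discriminator update; detach so nothing flows into target_generator.
    discriminator_optimizer.zero_grad()  # drop the unwanted target_lossA grads
    source_loss = loss(discriminator(source_outputs), torch.ones(n_s, 1))
    target_lossB = loss(discriminator(target_outputs.detach()), torch.zeros(n_t, 1))
    discriminator_loss = source_loss + target_lossB
    discriminator_loss.backward()
    discriminator_optimizer.step()
    return target_lossA.item(), discriminator_loss.item()

params = [*target_generator.parameters(), *discriminator.parameters()]
before = [p.detach().clone() for p in params]
gen_loss, disc_loss = train_step(torch.randn(8, 3), torch.randn(8, 3))
```

The discriminator is only stepped after its own backward() has run, and the second backward never passes through target_generator, so neither the stale-activation error nor the freed-graph error can occur.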

Best.

K. Frank

Thank you so much @ptrblck @KFrank for your help and for taking the time to answer so clearly and helping me understand not only the solution but the problem itself. This helped me a lot, and for a first experience using forums this was excellent.

@KFrank The target_generator has a custom GradientInversor function that is supposed to invert the gradients at the beginning of the backward pass. This was intended to make the loss of the target_generator the negative of the discriminator loss.

class GradientInversor(torch.autograd.Function):
    '''Gradient inversor to make the feature extractor loss
    the negative of the discriminator loss. Use as GradientInversor.apply(x).'''
    @staticmethod
    def forward(ctx, x):
        # Identity in the forward pass (view_as returns a new tensor object)
        return x.view_as(x)
    
    @staticmethod
    def backward(ctx, grad_output):
        # Return the inverted gradient
        return -grad_output
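A quick way to confirm that a gradient-inversion function behaves as intended is to compare the gradient with the value it would have without it. The sketch re-declares the class so it is self-contained; the input x and the function 2 * x are arbitrary choices for illustration.

```python
import torch

class GradientInversor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output  # flip the sign of the incoming gradient

x = torch.randn(3, requires_grad=True)
y = GradientInversor.apply(x)
(2 * y).sum().backward()

print(torch.equal(y, x))  # True: the forward pass is the identity
print(x.grad)             # every entry is -2 (it would be +2 without the inversion)
```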

I was worried about this approach and I think the solution given makes a lot of sense, thank you so much, this will help me a lot for my Masters Thesis.

Hi David!

Well that’s different. Using your GradientInversor is a logically sound scheme
to train target_generator in the “right” direction. I think the approach that I
outlined in my previous post where you first optimize target_generator and
then call source_loss.backward() should therefore work.
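A sketch of that ordering combined with GradientInversor is below. It assumes (hypothetically) that the inversion is applied directly to target_generator's output, and uses tiny nn.Linear stand-ins with SGD. It also checks that the gradient reaching target_generator is exactly the negative of the one plain target_loss would produce.

```python
import copy
import torch

class GradientInversor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

torch.manual_seed(0)
source_generator = torch.nn.Linear(3, 4)  # hypothetical stand-ins
target_generator = torch.nn.Linear(3, 4)
discriminator = torch.nn.Linear(4, 1)
loss = torch.nn.BCEWithLogitsLoss()
generator_optimizer = torch.optim.SGD(target_generator.parameters(), lr=0.1)
discriminator_optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.1)
source_inputs, target_inputs = torch.randn(5, 3), torch.randn(5, 3)
source_labels, target_labels = torch.ones(5, 1), torch.zeros(5, 1)

# Reference: generator gradient of plain target_loss, on a copy of the generator.
plain = copy.deepcopy(target_generator)
loss(discriminator(plain(target_inputs)), target_labels).backward()
plain_grad = plain.weight.grad.clone()
discriminator.zero_grad()  # discard the reference pass's discriminator grads

# Ordering from the previous post, with the gradient inverted at the generator output.
generator_optimizer.zero_grad()
discriminator_optimizer.zero_grad()
source_outputs = source_generator(source_inputs)
target_outputs = GradientInversor.apply(target_generator(target_inputs))
target_loss = loss(discriminator(target_outputs), target_labels)
target_loss.backward()  # discriminator: +grad, target_generator: -grad
inverted_grad = target_generator.weight.grad.clone()
generator_optimizer.step()
source_loss = loss(discriminator(source_outputs), source_labels)
source_loss.backward()  # does not pass through target_generator
discriminator_optimizer.step()
```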

Having said that, using GradientInversor trains target_generator with a loss
function that is, in essence:

target_lossBNegative = -loss(discriminator(target_outputs), target_labelsB)

(where target_labelsB are zeros).

On the other hand (without GradientInversor), my training proposal trains
target_generator with

target_lossA = loss(discriminator(target_outputs), target_labelsA)

(where target_labelsA are ones).

While target_lossA and target_lossBNegative play the same role and have
some similarities, they are not equal to one another.

I have never tested the two against one another, but my intuition tells me that
target_lossA is likely to work better because it has a logarithmic divergence
when target_generator is not doing well in fooling discriminator. This sort
of logarithmic divergence seems to me to be very helpful in effective training.
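One way to see this intuition (a sketch, not a benchmark) is to look at the gradient each loss sends back to a single discriminator logit z for a target embedding. The three logit values below are arbitrary illustrations. When the discriminator confidently flags the embedding as fake (z very negative), target_lossA still returns a gradient of magnitude close to 1, while the gradient of target_lossBNegative almost vanishes. This is the same argument commonly made for the “non-saturating” generator loss in GANs.

```python
import torch
import torch.nn.functional as F

# Discriminator logits for three target embeddings:
# confidently "fake", undecided, confidently "source".
z = torch.tensor([-10.0, 0.0, 10.0], requires_grad=True)

# target_lossA: labels are ones (reward fooling the discriminator).
target_lossA = F.binary_cross_entropy_with_logits(
    z, torch.ones_like(z), reduction="sum")
# target_lossBNegative: negated loss against zeros (GradientInversor-style).
target_lossBNegative = -F.binary_cross_entropy_with_logits(
    z, torch.zeros_like(z), reduction="sum")

(grad_A,) = torch.autograd.grad(target_lossA, z)
(grad_BNeg,) = torch.autograd.grad(target_lossBNegative, z)
# grad_A    equals -sigmoid(-z): about -1.0, -0.5, -0.00005
# grad_BNeg equals -sigmoid(z):  about -0.00005, -0.5, -1.0
```

Both gradients have the same sign, so both push z toward “source”; they differ in where the push is strong.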

If you have the time and energy, it might be useful to compare the two methods
head-to-head on your specific concrete problem and see if one or the other is
clearly superior.

Good luck!

K. Frank


What you’ve done here gives me an error unless I specify retain_graph=True for target_loss.backward(). However, you’ve said that it is not necessary. Can you please explain?

         # Process each input with corresponding generator
         source_outputs = source_generator(source_inputs)[0]
         target_outputs = target_generator(target_inputs)[0] 

         # Compute discriminator
         source_labels = torch.Tensor(np.array([np.ones(1) for _ in range(len(source_inputs))])).to(device)
         target_labels = torch.Tensor(np.array([np.zeros(1) for _ in range(len(target_inputs))])).to(device)

         # Zero grads
         generator_optimizer.zero_grad()
         discriminator_optimizer.zero_grad()

         target_loss = loss(discriminator(target_outputs), target_labels)

         # Do backpropagation for generator and PART OF discriminator
         target_loss.backward()
         generator_optimizer.step()
         source_loss = loss(discriminator(source_outputs), source_labels)

         # Do backpropagation for THE REST OF discriminator
         source_loss.backward()
         discriminator_optimizer.step()