Normalize gradients of loss functions w.r.t. the input before backpropagating

I have been trying to implement this paper on the adversarial generation of handwritten images from a text string.

My model consists of two networks:

  1. A CGAN architecture consisting of an encoder, a generator and a discriminator. The encoder encodes the raw text string into a high-dimensional feature vector, while the decoder (which is actually the generator) takes the encoded string and a noise vector and tries to output fake word images. The discriminator, on the other hand, is trained to tell real and fake images apart, as is the norm in a standard GAN architecture.
  2. Besides the above components, my network also includes a pre-trained recognizer (OCR) module, which takes an input image and outputs a string. The OCR is an attention-based encoder-decoder network, and its loss function is NLL. The reason for adding a recognizer on top of the discriminator is that I want the content of the image to be the given word: the discriminator alone is not enough, since it can only tell whether an image looks real or fake.

Since there are two different architectures and two different loss functions, the gradients of each loss w.r.t. the input differ in magnitude. As mentioned in the paper (Section 3.4), the norm of the gradient of the NLL loss w.r.t. the input is 10^2 times the norm of the gradient of the discriminative loss w.r.t. the input. As a result, I am not able to train the network properly.
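To check this imbalance in a given model, the gradient of each loss w.r.t. the generated images can be measured separately with `torch.autograd.grad`. A minimal sketch, assuming PyTorch; `grad_norms` is a hypothetical helper, and the linear layer and the two toy losses only stand in for the real generator, discriminator loss and recognizer loss (the factor of 100 is arbitrary).

```python
import torch

def grad_norms(loss_a, loss_b, x):
    """Return the L2 norms of d(loss_a)/dx and d(loss_b)/dx."""
    # retain_graph=True keeps the graph alive for the second call
    (g_a,) = torch.autograd.grad(loss_a, x, retain_graph=True)
    (g_b,) = torch.autograd.grad(loss_b, x, retain_graph=True)
    return g_a.norm().item(), g_b.norm().item()

# Toy stand-ins (hypothetical): a "generated image" and two losses on it
torch.manual_seed(0)
gen = torch.nn.Linear(8, 16)               # stands in for the generator
fake_images = gen(torch.randn(4, 8))       # non-leaf tensor, requires grad
loss_d = fake_images.mean()                # stands in for the discriminator loss
loss_r = 100.0 * (fake_images ** 2).sum()  # stands in for the recognizer NLL loss

norm_d, norm_r = grad_norms(loss_d, loss_r, fake_images)
print(f"ratio recog/disc = {norm_r / norm_d:.1f}")
```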

The paper describes a method for balancing the gradient of the recognizer against that of the discriminator: the recognizer gradient is multiplied by a scalar so that both gradients lie in the same range before they are backpropagated to the generator.
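A rule of this kind can be written as below. The specific form (rescaling by the ratio of the standard deviations σ of the two gradient tensors, with a hyperparameter α) is an assumption for illustration; Section 3.4 of the paper should be checked for its exact formula. Here x is the generated image, ℓ_D the discriminator loss and ℓ_R the recognizer (NLL) loss.

$$
\nabla_x \ell = \nabla_x \ell_D + \alpha \, \frac{\sigma(\nabla_x \ell_D)}{\sigma(\nabla_x \ell_R)} \, \nabla_x \ell_R
$$

With α = 1, the rescaled recognizer term has the same standard deviation as the discriminator gradient, regardless of how large the raw ratio between the two gradients is.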

However, I don't understand how to do this at the implementation level.


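    fake_images = self.gen.decode(decoder_input)
    # NLL loss from the pre-trained recognizer (OCR) on the generated images
    loss_recog = self.nll_criterion(fake_images, targets)
    # adversarial (generator) loss from the discriminator
    loss_gen = self.disc.calc_gen_loss(fake_images)
    loss_gen_total = loss_gen + loss_recog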
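One way to implement the balancing, sketched under assumptions: compute the gradient of each loss w.r.t. `fake_images` separately with `torch.autograd.grad`, rescale the recognizer gradient, add the two, and continue backpropagation from `fake_images` into the generator with `fake_images.backward(gradient=...)`. The std-ratio scale and `alpha` follow the assumed rule above, not necessarily the paper's exact formula; `balanced_backward` is a hypothetical helper, and the linear layer and toy losses only stand in for the real generator, discriminator and OCR.

```python
import torch
import torch.nn.functional as F

def balanced_backward(fake_images, loss_disc, loss_recog, alpha=1.0, eps=1e-12):
    """Backpropagate grad(loss_disc) + scale * grad(loss_recog) from
    fake_images into the generator, where scale matches the stds."""
    # Gradients of each loss w.r.t. the generated images only; these are
    # returned, not accumulated into any parameter's .grad
    (grad_d,) = torch.autograd.grad(loss_disc, fake_images, retain_graph=True)
    (grad_r,) = torch.autograd.grad(loss_recog, fake_images, retain_graph=True)

    # Rescale the recognizer gradient to the spread of the discriminator one
    scale = alpha * grad_d.std() / (grad_r.std() + eps)
    combined = grad_d + scale * grad_r

    # Continue backprop from fake_images into the generator parameters
    fake_images.backward(gradient=combined)
    return scale.item()

torch.manual_seed(0)
gen = torch.nn.Linear(8, 16)                        # hypothetical generator
fake_images = gen(torch.randn(4, 8))
loss_disc = F.softplus(-fake_images).mean()         # toy GAN-style generator loss
loss_recog = 100.0 * (fake_images ** 2).mean()      # toy stand-in for the OCR NLL
scale = balanced_backward(fake_images, loss_disc, loss_recog)
print(gen.weight.grad.shape)  # torch.Size([16, 8])
```

In the snippet above, this would take the place of calling `loss_gen_total.backward()`: something like `balanced_backward(fake_images, loss_gen, loss_recog)` followed by the generator's optimizer step. Because `torch.autograd.grad` returns gradients instead of writing them to `.grad`, neither the discriminator's nor the pre-trained recognizer's parameters receive gradients from this step.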

I have tried to modify the gradients w.r.t. the fake images by reading `fake_images.grad`. However, it is `None`. This is after I have set `fake_images.requires_grad = True` and also `fake_images.retain_grad = True`.
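The `None` is easy to reproduce in isolation. `.grad` is only filled in by `backward()`, and for a non-leaf tensor such as `fake_images` it is only kept if `retain_grad()` is *called* before `backward()`; it is a method, so assigning `fake_images.retain_grad = True` does not call it. A minimal sketch with toy tensors (`w` and `fake` are hypothetical stand-ins):

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # leaf tensor (like a generator weight)
fake = w * 2                            # non-leaf tensor (like fake_images)

fake.retain_grad()        # method call; must happen before backward()
print(fake.grad)          # None: nothing has been backpropagated yet

fake.sum().backward()
print(fake.grad)          # tensor([1., 1., 1.])
```

Even then, `fake_images.grad` would hold the *sum* of the gradients of all losses backpropagated through it, so it cannot separate the discriminator and recognizer contributions; computing each one with `torch.autograd.grad` avoids that.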

Any suggestions on how to modify the gradient of one loss relative to the other would be highly appreciated.
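If a fixed scale is good enough, another option is to rescale only the gradient flowing back from the recognizer, by feeding the recognizer a differentiable copy of `fake_images` that carries a gradient hook (`Tensor.register_hook`). A minimal sketch; `recognizer_branch` is a hypothetical helper, and the toy losses and the value `0.01` are placeholders.

```python
import torch

def recognizer_branch(fake_images, scale):
    """Return a copy of fake_images whose incoming gradient is multiplied
    by scale, so only the gradient from the recognizer path is rescaled."""
    branch = fake_images.clone()               # differentiable identity copy
    branch.register_hook(lambda g: g * scale)  # runs during backward()
    return branch

torch.manual_seed(0)
gen = torch.nn.Linear(8, 16)                   # hypothetical generator
fake_images = gen(torch.randn(4, 8))

loss_gen = fake_images.mean()                                     # toy disc loss
loss_recog = (recognizer_branch(fake_images, 0.01) ** 2).mean()   # toy OCR loss
(loss_gen + loss_recog).backward()
```

The discriminator still sees the original tensor, so its gradient is unchanged. The trade-off compared with computing both gradients explicitly is that this scale does not adapt as the gradient magnitudes change during training.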