NaNs in gradients due to multi-objective loss function

Context

I’m currently trying to implement the following architecture: an audio signal x is passed through a linear filter h_\theta, and the filtered signal is fed in parallel to f_\phi, a network that predicts one of two classes (either y_hat = 0 or y_hat = 1), and to g_\psi, an autoencoder that attempts to reconstruct x as x_hat.

Here is how I pass a batch of x’s through the filter h_\theta:

x = filt(x)

Where filt is an nn.Module object with trainable parameters:

filt = Filter(bias = True, kernel_size = 201, identity = False).to(device)
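
The Filter class itself isn’t shown here; purely as a sketch of what such a learnable FIR filter h_\theta might look like (this is an assumption, not the actual implementation), it could wrap a single trainable nn.Conv1d:

import torch
import torch.nn as nn

class Filter(nn.Module):
    # sketch of a learnable FIR filter h_theta (assumed, not the original code)
    def __init__(self, bias=True, kernel_size=201, identity=False):
        super().__init__()
        # one input and one output channel; 'same' padding keeps the signal length
        self.conv = nn.Conv1d(1, 1, kernel_size,
                              padding=kernel_size // 2, bias=bias)
        if identity:
            # initialize to a Dirac impulse so the filter starts as an identity
            with torch.no_grad():
                self.conv.weight.zero_()
                self.conv.weight[0, 0, kernel_size // 2] = 1.0
                if bias:
                    self.conv.bias.zero_()

    def forward(self, x):
        # x: [batch, samples] -> add a channel dim for Conv1d, then drop it again
        return self.conv(x.unsqueeze(1)).squeeze(1)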

After passing x through the filter h_\theta, here is how I pass it through the network f_\phi:

def run_batch_disc(mode,x,labels,transforms,net,loss_func,device):
    
    log, mel, spec = transforms
    
    # compute logarithmic Mel-scale magnitude spectrograms of the filtered
    # signals
    
    x = spec(x)
    x = torchaudio.functional.complex_norm(x)
    x = mel(x)
    x = log(x)
    
    # with torch.set_grad_enabled(mode == 'train'):        
    
    # logits must have the same shape as labels
    
    logits = net(x).squeeze(dim = 1)
    
    # compute negative log-likelihood (NLL) using logits
    
    NLL = loss_func(logits,labels)
    
    # record predictions. since sigmoid(0) = 0.5, then negative values
    # correspond to class 0 and positive values correspond to class 1
    
    preds = logits > 0
    
    # record correct predictions
    
    true_preds = torch.sum(preds == labels)
    
    return NLL,true_preds.item()
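
The classification loss_func isn’t shown either; since it takes raw logits and float labels, it is presumably something like nn.BCEWithLogitsLoss (an assumption on my part). The preds = logits > 0 shortcut is equivalent to thresholding the sigmoid at 0.5:

import torch

disc_loss_func = torch.nn.BCEWithLogitsLoss()   # assumed; expects logits and float labels

logits = torch.tensor([-1.2, 0.3, 2.0])
print(torch.sigmoid(logits) > 0.5)   # tensor([False,  True,  True])
print(logits > 0)                    # identical, since sigmoid(0) = 0.5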

Where log, mel, and spec are:

log = torchaudio.transforms.AmplitudeToDB(stype = 'magnitude',
                                          top_db = 80).to(device)

# settings for spectrograms

win_length_sec = 0.012
win_length = int(sample_rate * win_length_sec)
# need 50% overlap to satisfy constant-overlap-add constraint to allow
# for perfect reconstruction using inverse STFT
hop_length = int(sample_rate * win_length_sec / 2) # used to be 64
n_mels = 128
n_fft = 512 # used to be 1024

mel = torchaudio.transforms.MelScale(n_mels = n_mels,
                                     sample_rate = sample_rate,
                                     f_min = 0.0,
                                     f_max = None,
                                     n_stft = n_fft // 2 + 1).to(device)

spec = torchaudio.transforms.Spectrogram(n_fft = n_fft,
                                         win_length = win_length,
                                         hop_length = hop_length,
                                         pad = 0,
                                         window_fn = torch.hann_window,
                                         power = None,
                                         normalized = False,
                                         wkwargs = None).to(device)
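
For reference, with an assumed sample_rate of 16000 Hz (the actual value isn’t shown above), these settings work out to:

sample_rate = 16000                  # assumed, for illustration only
win_length = int(16000 * 0.012)      # = 192 samples (12 ms window)
hop_length = int(16000 * 0.012 / 2)  # = 96 samples, i.e. 50% overlap
# the 192-sample Hann window is zero-padded to n_fft = 512 before the FFT,
# giving 512 // 2 + 1 = 257 frequency bins, which is why MelScale is built
# with n_stft = n_fft // 2 + 1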

In parallel, I pass the filtered x through the autoencoder g_\psi as follows:

def run_batch_recon(mode,x,spec,net,loss_func,device):
    
    # compute magnitude spectrogram of the filtered signal
    
    x = torchaudio.functional.complex_norm(spec(x))
    
    # scale each magnitude spectrogram in the batch to the interval [0,1]
    
    scale_factor = x.amax(dim=(2,3))[(..., ) + (None, ) * 2]
    x_scaled = x / scale_factor
    
    # with torch.set_grad_enabled(mode == 'train'):
        
    # compute reconstruction of the scaled magnitude spectrogram
    
    x_hat = net(x_scaled)
    
    # unscale magnitude spectrogram back to normal values
    
    x_hat = x_hat * scale_factor
    
    # compute reconstruction loss
    
    recon_loss = loss_func(x_hat,x)
    
    return recon_loss
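
The fancy indexing used for scale_factor just re-appends the two reduced dimensions so that the division broadcasts over frequency and time; if it helps readability, it is equivalent to keepdim = True (a quick check, assuming magnitude spectrograms of shape [batch, channel, freq, time]):

import torch

x = torch.rand(8, 1, 257, 100)                  # [batch, channel, freq, time]
a = x.amax(dim=(2, 3))[(..., ) + (None, ) * 2]  # shape [8, 1, 1, 1]
b = x.amax(dim=(2, 3), keepdim=True)            # shape [8, 1, 1, 1]
print(torch.equal(a, b))                        # True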

Where loss_func is:

def recon_loss_func(x_hat,x,alpha = 1):
    
    # spectral convergence
    
    num = (x - x_hat).pow(2).sum(dim=(2,3)).sqrt().squeeze()
    den = x.pow(2).sum(dim=(2,3)).sqrt().squeeze()
    spec_conv = torch.div(num,den)
    
    # log-scale STFT magnitude loss
    
    eps = 1e-10
    num = (torch.log(x + eps) - torch.log(x_hat + eps)).abs().sum(dim=2).amax(dim=2).squeeze()
    den = torch.log(x + eps).abs().sum(dim=2).amax(dim=2).squeeze()
    log_loss = torch.div(num,den)
    
    return torch.mean(spec_conv + alpha * log_loss)
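
For shape reference, a quick smoke test of recon_loss_func (a sketch, assuming magnitude spectrograms of shape [batch, channel, freq, time]):

import torch

x = torch.rand(4, 1, 257, 100)      # "true" magnitude spectrograms
x_hat = torch.rand(4, 1, 257, 100)  # reconstructions
loss = recon_loss_func(x_hat, x)
print(loss)                         # a single scalar averaged over the batch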

Objective

I am trying to choose the parameters \theta, \phi, and \psi to minimize the following objective function:

L(\theta, \phi, \psi) = L_y(y, y_hat) - \lambda_adv * L_x(x, x_hat)

Where L_y is the binary cross entropy loss of predicting y_hat not equal to y (NLL in the code above) and L_x is a modified RMSE loss between x and its reconstruction x_hat (recon_loss_func above). Essentially, I am trying to minimize the binary cross entropy loss while maximizing the reconstruction loss. The purpose of this is to choose the filter parameters \theta that balance these two objectives.

Here is how I implement this objective function:

def run_batch(batch,mode,transforms,filt,nets,loss_funcs,lambda_adv,
              optimizer,device):
    
    # unpack everything
    
    x,labels = batch
    disc_net,recon_net = nets
    disc_loss_func, recon_loss_func = loss_funcs
    
    # move to GPU if available
    
    x = x.to(device)
    labels = labels.to(device).type_as(x) # needed for NLL
    
    with torch.set_grad_enabled(mode == 'train'):
        
        # filter each signal
        
        x = filt(x)
        
        # compute classification (negative log-likelihood (NLL)) loss and
        # number of correct predictions
        
        NLL, true_preds = run_batch_disc(mode,x,labels,transforms,disc_net,
                                         disc_loss_func,device)
        
        # extract speech signals from x and compute reconstruction loss
        
        where_speech = torch.nonzero(labels == 0, as_tuple = True)
        recon_loss = run_batch_recon(mode,x[where_speech],transforms[2],
                                     recon_net,recon_loss_func,device)
        
        # compute adversarial loss
        
        adv_loss = NLL - lambda_adv * recon_loss
        
        if mode == 'train':
        
            # compute gradient of adversarial loss with respect to filter,
            # discriminator network, and reconstruction network parameters
            
            adv_loss.backward()
            
            # update parameters using the gradient descent update rule
            
            optimizer.step()
            
            # zero the accumulated parameter gradients
            
            optimizer.zero_grad()
    
    return NLL.item(),true_preds,recon_loss.item(),adv_loss.item()
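
For context, run_batch is called from a run_epoch loop roughly like this (a sketch only; run_epoch isn’t shown, and names such as train_loader and the Adam optimizer are assumptions):

optimizer = torch.optim.Adam(list(filt.parameters()) +
                             list(disc_net.parameters()) +
                             list(recon_net.parameters()))

for batch in train_loader:
    NLL, true_preds, recon_loss, adv_loss = run_batch(
        batch, 'train', (log, mel, spec), filt,
        (disc_net, recon_net), (disc_loss_func, recon_loss_func),
        lambda_adv, optimizer, device)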

The Problem

As I try to implement this, I keep getting all NaNs in the gradients of the filter parameters \theta once I call .backward() on the objective function shown above, while the gradients of \phi and \psi (the parameters of f and g, respectively) contain no NaNs.

I have tried to set

torch.autograd.set_detect_anomaly(True)

And I get the following RuntimeError and traceback:

RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py:130: UserWarning: Error detected in PowBackward0. Traceback of forward call that caused the error:
  File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 323, in <module>
    train_stats = run_epoch('train',
  File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 151, in run_epoch
    stats = run_batch(batch,mode,transforms,filt,nets,loss_funcs,
  File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 94, in run_batch
    NLL, true_preds = run_batch_disc(mode,x,labels,transforms,disc_net,
  File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 19, in run_batch_disc
    x = torchaudio.functional.complex_norm(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torchaudio\functional.py", line 424, in complex_norm
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)
 (Triggered internally at  ..\torch\csrc\autograd\python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(

The part where this RuntimeError happens is here:

def run_batch_disc(mode,x,labels,transforms,net,loss_func,device):
    
    log, mel, spec = transforms
    
    # compute logarithmic Mel-scale magnitude spectrograms of the filtered
    # signals
    
    x = spec(x)
    x = torchaudio.functional.complex_norm(x)
    x = mel(x)
    x = log(x)

I have made sure that x = spec(x) is >= 0, but I still get this problem.
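
Here is roughly how I checked that (a sketch, where x is a filtered batch already on the device; the check is on the magnitudes after complex_norm, since spec with power = None returns a complex representation):

with torch.no_grad():
    mag = torchaudio.functional.complex_norm(spec(x))
    print('any negative magnitudes:', (mag < 0).any().item())
    print('min magnitude          :', mag.min().item())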

I have noticed that if I remove the autoencoder g_\psi and just pass x through h_\theta and f_\phi, calling backward on the binary cross entropy loss alone, there are no NaNs in the gradients of the filter parameters.

Could you make sure that the log operation returns valid outputs, e.g. that x is not zero?
Given that spec(x) can be zero, would it also be possible that complex_norm and mel return zeros?
If so, then torch.log(torch.zeros(1)) would return tensor([-inf]), which might raise this issue in the backward pass.
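
For illustration, a minimal reproduction of why this surfaces as PowBackward0 (the same pow(2).sum(-1).pow(0.5) pattern that complex_norm uses internally):

import torch

z = torch.zeros(2, requires_grad=True)  # stands in for an exactly-zero STFT bin
y = z.pow(2.).sum(-1).pow(0.5)          # same form as complex_norm
y.backward()
print(z.grad)                           # tensor([nan, nan]): the derivative of sqrt(z^2) is 0/0 at z = 0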

This solved the problem!

Just needed to change this:

x = spec(x)
x = torchaudio.functional.complex_norm(x)
x = mel(x)
x = log(x)

To this:

eps = 1e-9
x = spec(x)
x = torchaudio.functional.complex_norm(x + eps)
x = mel(x)
x = log(x)

I wasn’t sure whether log(x) would cause a problem if x contained zeros, so I added eps to it as well:

eps = 1e-9
x = spec(x)
x = torchaudio.functional.complex_norm(x + eps)
x = mel(x)
x = log(x + eps)