# NaNs in gradients due to multi-objective loss function

## Context

I’m currently trying to implement the following architecture:

Where x is an audio signal, h_\theta is a linear filter, f_\phi is a network that predicts one of two classes, either y_hat = 0 or y_hat = 1, and g_\psi is an autoencoder that attempts to reconstruct x as x_\hat.

Here is how I pass a batch of x’s through the filter h_\theta:

x = filt(x)


Where filt is a nn.Module object with trainable parameters:

filt = Filter(bias = True, kernel_size = 201, identity = False).to(device)


After passing x through the filter h_\theta, here is how I pass it through the network f_\phi:

def run_batch_disc(mode,x,labels,transforms,net,loss_func,device):

    log, mel, spec = transforms

    # compute logarithmic Mel-scale magnitude spectrograms of the filtered
    # signals

    x = spec(x)
    x = torchaudio.functional.complex_norm(x)
    x = mel(x)
    x = log(x)

    # logits must have the same shape as labels

    logits = net(x).squeeze(dim = 1)

    # compute negative log-likelihood (NLL) using logits

    NLL = loss_func(logits,labels)

    # record predictions. since sigmoid(0) = 0.5, negative logits
    # correspond to class 0 and positive logits correspond to class 1

    preds = logits > 0

    # record correct predictions

    true_preds = torch.sum(preds == labels)

    return NLL,true_preds.item()
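
The thresholding rule in the comment above can be checked with a few made-up logits. This is a minimal plain-Python sketch; the `sigmoid` helper and the sample values are illustrative and not part of the training code:

```python
import math

def sigmoid(z):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))

# made-up logits for illustration only
logits = [-2.0, -0.1, 0.0, 0.3, 4.0]

# thresholding the logit at 0 ...
by_logit = [z > 0 for z in logits]
# ... gives the same decisions as thresholding the probability at 0.5
by_prob = [sigmoid(z) > 0.5 for z in logits]

print(by_logit == by_prob)
```

Comparing logits against 0 avoids computing the sigmoid at all, which is why the function works on logits directly.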


Where log, mel, and spec are:

log = torchaudio.transforms.AmplitudeToDB(stype = 'magnitude',
                                          top_db = 80).to(device)

# settings for spectrograms

win_length_sec = 0.012
win_length = int(sample_rate * win_length_sec)
# need 50% overlap to satisfy constant-overlap-add constraint to allow
# for perfect reconstruction using inverse STFT
hop_length = int(sample_rate * win_length_sec / 2) # used to be 64
n_mels = 128
n_fft = 512 # used to be 1024

mel = torchaudio.transforms.MelScale(n_mels = n_mels,
                                     sample_rate = sample_rate,
                                     f_min = 0.0,
                                     f_max = None,
                                     n_stft = n_fft // 2 + 1).to(device)

spec = torchaudio.transforms.Spectrogram(n_fft = n_fft,
                                         win_length = win_length,
                                         hop_length = hop_length,
                                         window_fn = torch.hann_window,
                                         power = None,
                                         normalized = False,
                                         wkwargs = None).to(device)
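
The comment about 50% overlap and the constant-overlap-add (COLA) constraint can be checked numerically. The sketch below is plain Python, not torch code: `periodic_hann` and `overlap_add_ripple` are hypothetical helpers, and `sample_rate = 16000` is an assumption because the post does not state it. It sums shifted periodic Hann windows and measures how far the sum deviates from a constant:

```python
import math

def periodic_hann(n):
    """Periodic Hann window of length n, computed by hand."""
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * k / n) for k in range(n)]

def overlap_add_ripple(win_length, hop_length, n_frames=8):
    """Max minus min of the summed, shifted windows, away from the edges."""
    w = periodic_hann(win_length)
    total = [0.0] * (hop_length * (n_frames - 1) + win_length)
    for f in range(n_frames):
        for k, v in enumerate(w):
            total[f * hop_length + k] += v
    # drop one window length at each end, where fewer frames overlap
    steady = total[win_length:len(total) - win_length]
    return max(steady) - min(steady)

sample_rate = 16000  # assumed value, not given in the post
win_length = int(sample_rate * 0.012)
hop_length = int(sample_rate * 0.012 / 2)
print(win_length, hop_length, overlap_add_ripple(win_length, hop_length))
```

A ripple close to zero means the windows sum to a constant. Note that if `sample_rate * win_length_sec` came out odd, the `int()` truncation would make the hop slightly less than half the window and the sum would no longer be exactly constant.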


In parallel, I pass the filtered x through the autoencoder g_\psi as follows:

def run_batch_recon(mode,x,spec,net,loss_func,device):

    # compute magnitude spectrogram of the filtered signal

    x = torchaudio.functional.complex_norm(spec(x))

    # scale each magnitude spectrogram in the batch to the interval [0,1]

    scale_factor = x.amax(dim=(2,3))[(..., ) + (None, ) * 2]
    x_scaled = x / scale_factor

    # compute reconstruction of the scaled magnitude spectrogram

    x_hat = net(x_scaled)

    # unscale magnitude spectrogram back to normal values

    x_hat = x_hat * scale_factor

    # compute reconstruction loss

    recon_loss = loss_func(x_hat,x)

    return recon_loss


Where loss_func is:

def recon_loss_func(x_hat,x,alpha = 1):

    # spectral convergence

    num = (x - x_hat).pow(2).sum(dim=(2,3)).sqrt().squeeze()
    den = x.pow(2).sum(dim=(2,3)).sqrt().squeeze()
    spec_conv = torch.div(num,den)

    # log-scale STFT magnitude loss

    eps = 1e-10
    num = (torch.log(x + eps) - torch.log(x_hat + eps)).abs().sum(dim=2).amax(dim=2).squeeze()
    den = torch.log(x + eps).abs().sum(dim=2).amax(dim=2).squeeze()
    log_loss = torch.div(num,den)



## Objective

I am trying to choose the parameters \theta, \phi, and \psi to minimize the following objective function:

Here L_y is the binary cross-entropy loss between the prediction y_hat and the label y (NLL in the code above), and L_x is a modified RMSE-style reconstruction loss between x and its reconstruction x_hat (recon_loss_func above). Essentially, I am trying to minimize the binary cross-entropy loss while maximizing the reconstruction loss. The purpose of this is to choose the filter parameters \theta that balance these two objectives.

Here is how I implement this objective function:

def run_batch(batch,mode,transforms,filt,nets,loss_funcs,lambda_adv,
              optimizer,device):

    # unpack everything

    x,labels = batch
    disc_net,recon_net = nets
    disc_loss_func, recon_loss_func = loss_funcs

    # move to GPU if available

    x = x.to(device)
    labels = labels.to(device).type_as(x) # needed for NLL

    # filter each signal

    x = filt(x)

    # compute classification (negative log-likelihood (NLL)) loss and
    # number of correct predictions

    NLL, true_preds = run_batch_disc(mode,x,labels,transforms,disc_net,
                                     disc_loss_func,device)

    # extract speech signals from x and compute reconstruction loss

    where_speech = torch.nonzero(labels == 0, as_tuple = True)
    recon_loss = run_batch_recon(mode,x[where_speech],transforms[2],
                                 recon_net,recon_loss_func,device)

    if mode == 'train':

        # discriminator network, and reconstruction network parameters

        # update parameters using the gradient descent update rule

        optimizer.step()

        # zero the accumulated parameter gradients



## The Problem

When I run this, the gradients of the filter parameters \theta are all NaN as soon as I call .backward() on the objective function shown above, while the gradients of the parameters \phi and \psi of f and g, respectively, contain no NaNs.

To track this down, I enabled anomaly detection:

torch.autograd.set_detect_anomaly(True)


And I get the following RuntimeError and traceback:

RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py:130: UserWarning: Error detected in PowBackward0. Traceback of forward call that caused the error:
File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 323, in <module>
train_stats = run_epoch('train',
File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 151, in run_epoch
stats = run_batch(batch,mode,transforms,filt,nets,loss_funcs,
File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 94, in run_batch
NLL, true_preds = run_batch_disc(mode,x,labels,transforms,disc_net,
File "c:\users\mahmoud talaat\documents\ncsu\semesters\fall2020\ece695_masters_research\adversarial_train_val.py", line 19, in run_batch_disc
x = torchaudio.functional.complex_norm(x)
File "C:\ProgramData\Anaconda3\lib\site-packages\torchaudio\functional.py", line 424, in complex_norm
return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)
Variable._execution_engine.run_backward(


The part where this RuntimeError happens is here:

def run_batch_disc(mode,x,labels,transforms,net,loss_func,device):

    log, mel, spec = transforms

    # compute logarithmic Mel-scale magnitude spectrograms of the filtered
    # signals

    x = spec(x)
    x = torchaudio.functional.complex_norm(x)
    x = mel(x)
    x = log(x)


I have made sure that x = spec(x) is >= 0, but I still get this problem.

I have noticed that if I remove the autoencoder g_\psi, and just pass x through h and f and call backward on the binary cross entropy loss alone, there are no NaN’s in the gradients of the filter parameters.

Could you make sure that the log operation returns valid outputs, e.g. that x is not zero?
Given that spec(x) can be zero, would it also be possible that complex_norm and mel return zeros?
If so, then torch.log(torch.zeros(1)) would return tensor([-inf]), which might raise this issue in the backward pass.

This solved the problem!

Just needed to change this:

x = spec(x)
x = torchaudio.functional.complex_norm(x)
x = mel(x)
x = log(x)


To this:

eps = 1e-9
x = spec(x)
x = torchaudio.functional.complex_norm(x + eps)
x = mel(x)
x = log(x)
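
For anyone wondering why a tiny offset matters here: per the traceback, complex_norm computes `pow(2.).sum(-1).pow(0.5 * power)`, i.e. a square root of a sum of squares. The following is a hand-written sketch of the chain rule through that expression in plain Python. It is not torchaudio's or autograd's code, and `norm_grad` is a hypothetical helper. It shows how an exactly-zero bin has a valid forward value but a NaN gradient:

```python
import math

def norm_grad(re, im, upstream=1.0):
    """Gradient of sqrt(re**2 + im**2) w.r.t. (re, im) via the chain rule."""
    s = re * re + im * im
    # derivative of s ** 0.5 is 0.5 * s ** -0.5, which is infinite at s == 0
    d_sqrt = math.inf if s == 0.0 else 0.5 / math.sqrt(s)
    # derivative of re ** 2 is 2 * re, so a zero bin multiplies inf by 0
    return upstream * d_sqrt * 2.0 * re, upstream * d_sqrt * 2.0 * im

print(norm_grad(0.0, 0.0))    # exactly-zero bin
print(norm_grad(1e-9, 1e-9))  # same bin after adding a small offset
print(norm_grad(3.0, 4.0))    # ordinary bin
```

Under floating-point rules `inf * 0` is NaN, so one exactly-zero time-frequency bin is enough to poison every gradient upstream of it, including the filter parameters. Shifting the input by a small constant keeps the argument of the square root away from zero.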


I wasn’t sure whether log(x) would cause a problem if x contains zeros, so I added eps there as well:

eps = 1e-9
x = spec(x)
x = torchaudio.functional.complex_norm(x + eps)
x = mel(x)
x = log(x + eps)
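
On the last point: as far as I know, `log` here is `torchaudio.transforms.AmplitudeToDB`, which clamps its input to a small floor (`amin`, 1e-10) before taking `log10`, so a zero input should not produce `-inf` in the first place and the extra eps is harmless. Below is a scalar plain-Python model of that clamp-then-log behaviour. `amplitude_to_db` is a hypothetical stand-in, not the library function, and it leaves out the reference-level and `top_db` handling:

```python
import math

def amplitude_to_db(x, amin=1e-10, multiplier=20.0):
    """Scalar model of a clamp-then-log10 dB conversion for magnitude input."""
    # clamping to amin keeps log10 away from zero
    return multiplier * math.log10(max(x, amin))

print(amplitude_to_db(0.0))    # finite, thanks to the clamp
print(amplitude_to_db(1e-12))  # also clamped to amin
print(amplitude_to_db(1.0))    # unit magnitude
```

It is worth checking the installed torchaudio version to confirm the clamp is there before relying on it.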