Model output explodes, causing NaN loss after training for a few steps

Hello, I am training a customized TTS model based on StyleTTS2. After training for a few steps, the wav output value explodes for some samples in a batch (I checked the input; nothing unusual):
    [i.max() for i in output["y_pred_fake"]]
    [tensor(0.1555, device='cuda:0', grad_fn=),
     tensor(0.1796, device='cuda:0', grad_fn=),
     tensor(1.2186e+20, device='cuda:0', grad_fn=),
     tensor(1.9930e+09, device='cuda:0', grad_fn=),
     tensor(0.1968, device='cuda:0', grad_fn=),
     tensor(0.1774, device='cuda:0', grad_fn=),
     tensor(0.1347, device='cuda:0', grad_fn=),
     tensor(8.7039, device='cuda:0', grad_fn=)]

This causes the mel loss to become NaN.

From what I inspected:

  1. the problematic line is this torch.exp from the code below, which is an ISTFTNet vocoder
    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)
        
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)

            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        
        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) # this line makes y explode
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        out = self.stft.inverse(spec, phase)
        return out
  2. the gradient norm of some modules of the model is around 30 to 40 and increases every step until the predicted y value explodes
  3. the input data is normalized properly
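
To pin down which module's gradients blow up first, a small diagnostic like the following can help (a hedged sketch; `grad_norms` is a hypothetical helper, not part of StyleTTS2 or my training code):

```python
import torch

def grad_norms(model):
    """Return each parameter's gradient L2 norm, largest first.

    Hypothetical helper: call after loss.backward() to see which
    submodule's gradients are growing step over step."""
    norms = {}
    for name, p in model.named_parameters():
        if p.grad is not None:
            norms[name] = p.grad.detach().norm().item()
    # sort descending so the worst offenders come first
    return dict(sorted(norms.items(), key=lambda kv: -kv[1]))
```

Logging the top few entries every N steps would show whether the 30-40 norms are confined to one block (e.g. `conv_post`) or spread across the model.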

I tried:

  • clamping the "x" value so it does not exceed 6 to prevent the explosion, but then the value stays capped at 6 all the time and won't improve

  • doing gradient clipping, scaling the gradients so their norm does not exceed 1

    accelerator.clip_grad_norm_(model.parameters(), 1.0)

  • reducing the learning rate from 1e-4 to 1e-5
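
Another common mitigation (not something from my runs, just a pattern I could try) is to skip any batch whose loss is non-finite, so one bad update does not poison the optimizer state:

```python
import torch

def step_if_finite(loss, optimizer):
    """Hypothetical helper: take an optimizer step only when the
    loss is finite. Returns True if a step was taken."""
    if not torch.isfinite(loss):
        optimizer.zero_grad(set_to_none=True)  # discard stale gradients
        return False
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return True
```

This does not fix the divergence itself, but it keeps a single exploding sample from wrecking the rest of the run while debugging.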

How do I solve this?

Hi Schnekk!

This StyleTTS 2 paper seems to say that the StyleTTS 2 model uses generative-adversarial
training techniques. Generative-adversarial networks (GANs) are known to be subject to
training instabilities. (This is not surprising as the generator and discriminator in such models
are working against one another at cross purposes, so it’s plausible that they might push one
another’s training off into problematic directions.)

If your model uses GAN techniques, this could be the source of your problem. There’s a lot
of literature about how to stabilize GAN training (but I don’t have specific recommendations).

In general – not specifically in the context of GANs – I recommend starting with plain-vanilla
SGD (without momentum) as your optimizer and trying to train with a very small learning rate.
You might also try weight decay as a way to regularize the model weights. If you can train
stably with a small learning rate, you can try increasing it and adding momentum.

(Note that optimizers like Adam can train faster, but also tend to be less stable. If you’re
using something like Adam, definitely switch to SGD and only try going back to Adam
if and when you have training under control.)
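
As a minimal sketch of that conservative setup (the model and the exact hyperparameter values here are placeholders, not taken from StyleTTS2):

```python
import torch

# Placeholder model standing in for the TTS model.
model = torch.nn.Linear(16, 16)

# Plain-vanilla SGD: no momentum, a very small learning rate,
# and mild weight decay to regularize the weights.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-5,            # start very small; raise only once training is stable
    momentum=0.0,       # add momentum back later
    weight_decay=1e-4,  # mild L2 regularization
)
```

Once a few epochs run without divergence, the learning rate and momentum can be increased incrementally.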

It’s sometimes the case that the randomly initialized weights of a model might start out in a
location that is near sources of instability. If you train with a small learning rate for a while,
your model weights might move away from unstable locations, after which you might be able
to continue training with a larger learning rate.

I doubt that the exp() is the source of the problem. It causes the nan to show up sooner,
but I don't think that postponing the nan would really address the divergence problem.
Nonetheless, if you can figure out some way to work with spec in log space – that is, not
call exp() so that you work with log (spec) – it might be worth experimenting with
such an approach.
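
As a purely illustrative sketch of the log-space idea (the names `x_mag` and `target_mag` are assumptions, and this compares spectrogram magnitudes directly rather than going through the vocoder's iSTFT as the real model does):

```python
import torch
import torch.nn.functional as F

def log_spec_loss(x_mag, target_mag, eps=1e-5):
    """Hypothetical loss: treat the network output x_mag as a
    log-magnitude and compare it with log of the target magnitude,
    so exp() is never called. eps guards against log(0)."""
    return F.l1_loss(x_mag, torch.log(target_mag + eps))
```

Since exp() is never applied, a large x_mag yields a large but finite loss instead of an overflowing spectrogram.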

Good luck.

K. Frank
