Hello, I'm training a customized TTS model based on StyleTTS2. After a few training steps, the wav output value explodes for some samples in the batch (I checked the input; nothing unusual):
```python
[i.max() for i in output["y_pred_fake"]]
[tensor(0.1555, device='cuda:0', grad_fn=<...>),
 tensor(0.1796, device='cuda:0', grad_fn=<...>),
 tensor(1.2186e+20, device='cuda:0', grad_fn=<...>),
 tensor(1.9930e+09, device='cuda:0', grad_fn=<...>),
 tensor(0.1968, device='cuda:0', grad_fn=<...>),
 tensor(0.1774, device='cuda:0', grad_fn=<...>),
 tensor(0.1347, device='cuda:0', grad_fn=<...>),
 tensor(8.7039, device='cuda:0', grad_fn=<...>)]
```
This causes the mel loss to become NaN. From what I inspected:
- the problematic line is the `torch.exp` in this code, which is the ISTFTNet vocoder:
```python
def forward(self, x, s, f0):
    with torch.no_grad():
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        har_source, noi_source, uv = self.m_source(f0)
        har_source = har_source.transpose(1, 2).squeeze(1)
        har_spec, har_phase = self.stft.transform(har_source)
        har = torch.cat([har_spec, har_phase], dim=1)
    for i in range(self.num_upsamples):
        x = F.leaky_relu(x, LRELU_SLOPE)
        x_source = self.noise_convs[i](har)
        x_source = self.noise_res[i](x_source, s)
        x = self.ups[i](x)
        if i == self.num_upsamples - 1:
            x = self.reflection_pad(x)
        x = x + x_source
        xs = None
        for j in range(self.num_kernels):
            if xs is None:
                xs = self.resblocks[i * self.num_kernels + j](x, s)
            else:
                xs += self.resblocks[i * self.num_kernels + j](x, s)
        x = xs / self.num_kernels
    x = F.leaky_relu(x)
    x = self.conv_post(x)
    spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])  # this line makes y explode
    phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
    out = self.stft.inverse(spec, phase)
    return out
```
- the gradient norm of some modules of the model is around 30 to 40 and keeps growing every step until the predicted y value explodes
- the input data is normalized properly
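For reference, this is roughly how I measure the per-module gradient norms after each backward pass (a generic sketch with a stand-in model, not the vocoder itself):

```python
import torch
import torch.nn as nn

def grad_norms_by_module(model: nn.Module) -> dict:
    """Total L2 gradient norm per top-level submodule.
    Call after loss.backward() to see which module's gradients are growing."""
    norms = {}
    for name, module in model.named_children():
        grads = [p.grad.detach() for p in module.parameters() if p.grad is not None]
        if grads:
            norms[name] = torch.norm(torch.stack([g.norm(2) for g in grads])).item()
    return norms

# toy demonstration with a stand-in two-layer model
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 1))
model(torch.randn(2, 4)).sum().backward()
print(grad_norms_by_module(model))  # e.g. {'0': ..., '1': ...}
```

Logging this every step is how I can tell the norm climbs from ~30 to 40 and beyond before the output explodes.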
I tried:
- clamping the `x` value to not exceed 6 so it can't explode, but then the value stays pinned at 6 the whole time and never improves
- gradient clipping with `accelerator.clip_grad_norm_(model.parameters(), 1.0)` to rescale gradients so the norm doesn't exceed 1
- reducing the learning rate from 1e-4 to 1e-5
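One thing I haven't tried yet is skipping the optimizer step entirely when the loss goes non-finite, so a single exploding batch can't poison the weights. A generic sketch (the function name and loop shape are mine, not from the StyleTTS2 repo, and it uses plain PyTorch clipping rather than accelerate):

```python
import torch
import torch.nn as nn

def step_with_nan_guard(loss, model, optimizer, max_norm=1.0):
    """Backward + clipped step, but skip the update when the loss is
    non-finite. Returns True if the step was taken, False if skipped."""
    if not torch.isfinite(loss):
        optimizer.zero_grad(set_to_none=True)
        return False
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return True

# toy demonstration
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
ok = step_with_nan_guard(model(torch.randn(2, 4)).sum(), model, opt)
bad = step_with_nan_guard(torch.tensor(float("nan")), model, opt)
print(ok, bad)  # True False
```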
How do I solve this?
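For completeness, the variant I'm experimenting with now bounds only the log-magnitude half before the `exp`, with a much higher cap than the 6 I tried earlier so it only engages when values blow up (the cap value 11.5 is my own guess, not from the repo):

```python
import torch

LOG_MAG_MAX = 11.5  # exp(11.5) ≈ 1e5: huge for a speech magnitude, but finite

def bounded_spec(x: torch.Tensor, post_n_fft: int) -> torch.Tensor:
    """Drop-in for `spec = torch.exp(x[:, :post_n_fft // 2 + 1, :])` that
    caps only the upper end, so gradients still flow below the cap."""
    log_mag = x[:, :post_n_fft // 2 + 1, :]
    return torch.exp(log_mag.clamp(max=LOG_MAG_MAX))

x = torch.full((1, 10, 4), 200.0)      # pathological input: exp(200) overflows float32
spec = bounded_spec(x, post_n_fft=18)  # 18 // 2 + 1 = 10 magnitude bins
print(spec.isfinite().all().item())    # True
```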