PyTorch GPU inference out of memory (OOM)

Hello,
I've been stuck on this problem for about a week.
I'm currently working with the official HiFi-GAN code, using my own model.

The main problem is that memory usage on GPU 0 suddenly increases during validation until it runs out of memory.

I tried batch sizes of 8 and 16, but both ended in the same out-of-memory error.

I would appreciate any help with this problem. Thank you.

Validation

            if steps % a.validation_interval == 0:  # and steps != 0:
                generator.eval()
                torch.cuda.empty_cache()
                val_err_tot = 0
                with torch.no_grad():
                    for j, batch in enumerate(validation_loader):
                        x, y, _, y_mel = batch
                        y_g_hat = generator(x.to(device))
                        # torch.autograd.Variable is deprecated; a plain tensor is sufficient
                        y_mel = y_mel.to(device, non_blocking=True)
                        y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                      h.hop_size, h.win_size,
                                                      h.fmin, h.fmax_for_loss)
                        val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                        if j <= 4:
                            if steps == 0:
                                sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                                sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                            sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                            y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                         h.sampling_rate, h.hop_size, h.win_size,
                                                         h.fmin, h.fmax)
                            sw.add_figure('generated/y_hat_spec_{}'.format(j),
                                          plot_spectrogram(y_hat_spec[0].cpu().numpy()), steps)

                    val_err = val_err_tot / (j+1)
                    sw.add_scalar("validation/mel_spec_error", val_err, steps)

                generator.train()

        steps += 1

    scheduler_g.step()
    scheduler_d.step()
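
To narrow down where the memory grows, one option might be to log the allocated and peak CUDA memory after each validation batch, together with the input shape. The following is only a minimal sketch and not part of the HiFi-GAN code: `cuda_memory_mib` and `log_validation_memory` are hypothetical helper names, and it assumes the helper is called inside the validation loop right after `generator(x.to(device))`. If the peak jumps on batches with a long input, that would suggest the size of individual validation items is the cause rather than a leak.

```python
import torch


def cuda_memory_mib(device=None):
    """Return current and peak allocated CUDA memory in MiB (zeros if no GPU is available)."""
    if not torch.cuda.is_available():
        return {"allocated_mib": 0.0, "peak_mib": 0.0}
    return {
        "allocated_mib": torch.cuda.memory_allocated(device) / 2**20,
        "peak_mib": torch.cuda.max_memory_allocated(device) / 2**20,
    }


def log_validation_memory(batch_index, input_shape, device=None):
    """Print one line per validation batch so memory can be compared against input size."""
    stats = cuda_memory_mib(device)
    line = "val batch {}: input shape {}, allocated {:.1f} MiB, peak {:.1f} MiB".format(
        batch_index, input_shape, stats["allocated_mib"], stats["peak_mib"]
    )
    print(line)
    return line


# Hypothetical use inside the validation loop above:
#     y_g_hat = generator(x.to(device))
#     log_validation_memory(j, tuple(x.shape), device)
```

Calling `torch.cuda.reset_peak_memory_stats(device)` before the loop would make the peak value reflect validation only, not the preceding training steps.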