Exploding loss in encoder/decoder model

I’m trying to build a text-to-speech model in PyTorch using an encoder/decoder architecture on the LibriSpeech 100-hour dataset. The model takes in text and outputs a mel spectrogram, but I’m running into an issue where the loss explodes by the second or third batch, regardless of batch size.

I’m probably doing something stupid in my trainer, but I’m not sure what! Does anyone have any insight?

Minimal train.py - gist

model.py - gist

This is the terminal output when running the training script:

loss tensor(0.2049, device='cuda:0', grad_fn=<MSELossBackward0>)
loss tensor(8.0709e+17, device='cuda:0', grad_fn=<MSELossBackward0>)
loss tensor(9.1250e+29, device='cuda:0', grad_fn=<MSELossBackward0>)
... [loss will be inf]

Your loss calculation seems to overwrite the intermediate loss tensor in each iteration of the batch loop, so only the loss from the last sample is ever used:

        for b in range(BSIZE):
            loss = 0
            batch_ms_pred_list = []
            for i in tqdm(range(audio[b].shape[0])):
                ms_pred, dec_hidden, dec_cell = decoder_model(dec_input, dec_hidden, dec_cell, enc_outputs[b].unsqueeze(0))

                if random.uniform(0, 1) < p_teacher_forcing: # teacher forcing
                    dec_input = audio[b][i].unsqueeze(0).to(device)
                else: # no teacher forcing
                    dec_input = ms_pred

                # append mspred
                batch_ms_pred_list.append(ms_pred)

            # accumulate loss
            batch_ms_pred_list = torch.stack(batch_ms_pred_list).squeeze()
            loss = loss_fn(batch_ms_pred_list, audio[b].to(device))
            print('loss', loss)

Is the overwrite intentional? Also, could you try to overfit a small dataset (e.g. just 10 samples) and check whether the training still explodes?