CTCLoss gives None gradients

I’m trying to build a simple ASR model using the LibriSpeech dataset built into torchaudio, but I seem to have a problem with the CTCLoss function.
When I try to use nn.CTCLoss, the gradient printed for my loss is None.
I’m using PyTorch 1.4.0 and torchaudio 0.4.0.
I get the following output when running ./train.py:

Using device: cuda
  0%|          | 0/676 [00:00<?, ?it/s]Training epoch: 1
Loss: 3228.21142578125:   0%|          | 0/676 [00:01<?, ?it/s]
None
Loss: 7.94204209823249e+22:   0%|          | 1/676 [00:02<18:37,  1.65s/it]
None
Loss: nan:   0%|          | 2/676 [00:03<14:56,  1.33s/it]

Using the following code:
train.py

#!/usr/bin/env python3
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn import CTCLoss
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from torchaudio.datasets import LIBRISPEECH
from models.Basic import Basic
from tqdm import tqdm
import multiprocessing

dictionary = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "

def text_to_tensor(text, dictionary = dictionary):
  """
  This function will convert a string of text
  to a tensor of character indicies in the given dictionary.
  The indicies will start from 1, as 0 means the blank
  character.
  """
  return torch.tensor([
    dictionary.index(c) + 1 if c in dictionary else 0
    for c in list(text.upper())
  ])
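
# For example, with the default dictionary above:
#   text_to_tensor("AB C")  ->  tensor([ 1,  2, 28,  3])
# ("A" -> 1, "B" -> 2, " " -> 28, "C" -> 3)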

def pad_collate(datapoints):
  waveforms, sample_rates, utterances, speaker_ids, chapter_ids, utterance_ids = zip(*datapoints)
  batch_size = len(datapoints)
  waveform_lengths = torch.tensor([waveform.shape[1] for waveform in waveforms])
  # Each waveform is (1, T); transpose to (T, 1) so pad_sequence can
  # pad along the time axis, giving a batch of shape (N, T_max, 1).
  waveforms = pad_sequence([wave.T for wave in waveforms], batch_first=True)
  utterance_lengths = torch.tensor([len(utterance) for utterance in utterances])

  # We convert each label text to a tensor of dictionary indices
  # and concatenate them into a single 1-D tensor of shape
  # (sum(target_lengths),), which is one of the target formats
  # accepted by CTCLoss:
  # https://pytorch.org/docs/stable/nn.html#torch.nn.CTCLoss
  utterances = torch.cat([text_to_tensor(utterance) for utterance in utterances])
  
  return batch_size, waveforms, waveform_lengths, utterances, utterance_lengths

def train(num_epochs=10, batch_size=4, num_workers=multiprocessing.cpu_count()):
  dataset = LIBRISPEECH("../data", "dev-clean", download=True)
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate, num_workers=num_workers, pin_memory=True)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = Basic(n_classes = len(dictionary) + 1).to(device)
  optimizer = SGD(model.parameters(), lr=0.0001)
  loss_fn = CTCLoss()
  print(f"Using device: {device}")
  
  tqdm_dataloader = tqdm(dataloader)
  
  for epoch in range(num_epochs):
    print(f"Training epoch: {epoch+1}")
    for batch_size, X, X_lengths, y, y_lengths in tqdm_dataloader:
      # First we zero our gradients, to make everything work nicely.
      optimizer.zero_grad()

      # (N, T_max, 1) -> (N, 1, T_max), the layout Conv1d expects.
      X = X.permute(0, 2, 1).to(device)
      X_lengths = X_lengths.to(device)
      y = y.to(device)

      # We predict the outputs using our model
      # and permute the result to shape (T, N, C), where
      # T is the input sequence length, N is the batch size
      # and C is the number of classes. In our case C is the
      # length of the dictionary + 1, as we need one extra
      # class for the blank character.
      pred_y = model(X)
      pred_y = pred_y.permute(2, 0, 1)
      pred_y_lengths = model.forward_shape(X_lengths).to(device)
      
      loss = loss_fn(pred_y, y, pred_y_lengths, y_lengths)
      tqdm_dataloader.set_description(f"Loss: {loss.item()}")
      loss.backward()
      print(loss.grad)
      optimizer.step()


if __name__ == "__main__":
  train()

With the following model:
models/Basic.py

import torch.nn as nn
import torch.nn.functional as F

class Basic(nn.Module):
  def __init__(self, n_classes):
    super(Basic, self).__init__()
    self.c1 = nn.Conv1d(1, 128, 32)
    self.c2 = nn.Conv1d(128, n_classes, 64)

  def forward(self, X):
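    # X: (N, 1, T) raw waveform batch; returns (N, n_classes, T')
    # per-frame log-probabilities (T' < T because of the convolutions).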
    a = self.c1(X)
    b = F.relu(a)
    c = self.c2(b)
    return F.log_softmax(c, dim=1)

  def forward_shape(self, lengths):
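    # Each Conv1d with kernel size k (stride 1, no padding) shortens
    # the sequence by k - 1, so L_out = L_in - (32 - 1) - (64 - 1).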
    return lengths - 32 - 64 + 1 + 1

Your loss seems to be exploding, which is why it reaches nan after a couple of batches. Could you try lowering the learning rate or using an adaptive optimizer such as Adam? As an aside, the None you are printing is expected even when training works: loss.grad is only populated for leaf tensors (such as the model parameters) unless you call loss.retain_grad() before backward(), so it is not by itself a sign that gradients are missing.
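
For example, something like this in the training loop (the learning rate and max_norm values here are guesses you would still need to tune):

# Adam instead of SGD:
optimizer = Adam(model.parameters(), lr=1e-4)

# ...and inside the batch loop, clip the gradient norm after backward()
# to keep a single bad batch from blowing up the weights:
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()

If you want to verify gradients are flowing, print a parameter gradient such as model.c1.weight.grad instead of loss.grad.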

Thanks! I think that was the problem!
After changing the model from operating directly on the waveform to operating on MFCC features, it worked! I think the input lengths when operating on the raw waveform were simply too large relative to the target lengths for the CTC alignment to work properly.
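
For anyone finding this later, the frontend change was roughly the following sketch (the n_mfcc value is just an example, not necessarily what you would want to keep):

import torchaudio

# 40 MFCC coefficients per frame; with the default hop length
# (200 samples at 16 kHz) this shrinks the time axis by roughly 200x,
# bringing the input lengths much closer to the target lengths.
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=40)

def waveform_to_features(waveform):
  # waveform: (1, num_samples) -> features: (n_mfcc, num_frames)
  return mfcc_transform(waveform).squeeze(0)

The first convolution then needs n_mfcc input channels instead of 1 (e.g. nn.Conv1d(40, 128, 32)), and forward_shape has to be fed the number of frames rather than the number of raw samples.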