PyTorch: loss.backward() keeps running for days

I’m using the pretrained model from deepspeech_pytorch (GitHub - SeanNaren/deepspeech.pytorch: Speech Recognition using DeepSpeech2.) to do some speech recognition tasks.

However, loss.backward() keeps running for days during the training phase.

The training-phase code block:

adversary_model = AdversaryModel( ... )

adv_trainset = WavDataset(adv_trainset_wav_path_list, model.labels, adv_train_spect_parser)
adv_trainset_dataloader = AudioAdvTrainDataLoader(dataset=adv_trainset, batch_size=FLAGS.batch_size, pin_memory=True, shuffle=True)
model.train()

# Because my task is something like an adversarial-example attack,
# the only parameter I need to train is "delta", which is declared
# separately as an nn.Parameter(); the parameters inside the neural
# network itself stay fixed. (A simplified sketch of this setup
# follows the training code.)
opt = torch.optim.Adam([adversary_model.delta], lr=FLAGS.adv_lr)

adversary_model.to(FLAGS.device)
adversary_model.model.to(FLAGS.device)

for ep in range(FLAGS.adv_train_epochs):
    print(f'\nTrain epoch = { ep }\n', flush=True)
    adversary_model.train()
    for (batch_idx, batch) in enumerate(adv_trainset_dataloader):
        opt.zero_grad()
        

        # A lot of code to feed the data batch into the model and compute the loss ...

        loss *= adversary_model.factor_loss
        
        print(f'Adv train epoch = { ep }, batch_idx = { batch_idx }, loss.backward() start ...', flush=True)
        loss.backward()
        print(f'Adv train epoch = { ep }, batch_idx = { batch_idx }, loss.backward() finish ...', flush=True)
        
        print(f'Adv train epoch = { ep }, batch_idx = { batch_idx }, opt.step() start ...', flush=True)
        opt.step()
        print(f'Adv train epoch = { ep }, batch_idx = { batch_idx }, opt.step() finish ...', flush=True)
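
For clarity, this is roughly how the delta-only setup looks inside AdversaryModel (a simplified sketch with assumed constructor arguments, not my exact class):

import torch
import torch.nn as nn

class AdversaryModel(nn.Module):
    # Simplified sketch: the pretrained DeepSpeech model stays frozen,
    # and only the additive perturbation "delta" is trainable.
    def __init__(self, model, delta_shape, factor_loss=1.0):
        super().__init__()
        self.model = model
        for p in self.model.parameters():
            p.requires_grad_(False)                           # freeze the network weights
        self.delta = nn.Parameter(torch.zeros(delta_shape))   # the only trained tensor
        self.factor_loss = factor_loss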

And the stdout from the training loop above:

Train epoch = 0

Adv train epoch = 0, batch_idx = 0, loss.backward() start ...
Adv train epoch = 0, batch_idx = 0, loss.backward() finish ...
Adv train epoch = 0, batch_idx = 0, opt.step() start ...
Adv train epoch = 0, batch_idx = 0, opt.step() finish ...
Adv train epoch = 0, batch_idx = 1, loss.backward() start ...

Without any warnings or errors, it just keeps running for days and never stops …

My current batch_size is 32, and I’ve tried reducing it to 16, 8, 4, 2, and 1. I found that the issue still occurs around the 40th sample.

It seems that the convolution reverb function is what triggers the problem, because when I remove this step everything works fine:

import torch
from torch import Tensor

def convolution_reverb(speech: Tensor, rir: Tensor):
    # speech, rir: 1-D time-domain tensors
    speech = speech.unsqueeze(0)
    rir = rir.unsqueeze(0)

    # Normalize the RIR and flip it so conv1d performs a true convolution
    # rather than a cross-correlation.
    rir = rir / torch.linalg.vector_norm(rir, ord=2)
    rir = torch.flip(rir, [1])

    # Left-pad the speech so the output keeps the original length.
    speech = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
    reverbed = torch.nn.functional.conv1d(speech[None, ...], rir[None, ...])[0]

    return reverbed[0]

This convolution_reverb function is called inside the parse_audio phase, which happens in the __getitem__ method of my dataset class. It simulates room reverberation on the clean audio. The input argument speech is the input audio tensor (time domain), and rir is a sample from an open-source RIR dataset (also time domain), which is preloaded onto CUDA. The code is copied from the torchaudio docs (Audio Data Augmentation — Torchaudio 0.10.0 documentation).
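
Roughly, the __getitem__ looks like this (a simplified sketch with hypothetical helper and attribute names, not my exact code):

import random
from torch.utils.data import Dataset

class WavDataset(Dataset):
    # Simplified sketch: load_audio, FLAGS, spect_parser, and the attribute
    # names are placeholders for pieces defined elsewhere in my project.
    def __getitem__(self, index):
        speech = load_audio(self.wav_path_list[index]).to(FLAGS.device)  # 1-D waveform on GPU
        rir = random.choice(self.rir_list)                               # RIR preloaded to CUDA
        reverbed = convolution_reverb(speech, rir)                       # apply room reverb
        spect = self.spect_parser.parse_audio(reverbed)                  # spectrogram for DeepSpeech
        return spect, self.transcript_paths[index]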

I’ve put all the tensor computation on the GPU, and I’ve checked that both system memory (64 GB total) and GPU memory (48 GB total across two GeForce RTX 3090s) are sufficient. I’ve tried to debug this code but found nothing useful.

Version information for my code and environment:

  • OS: Linux 5.4.0 / Ubuntu 18.04
  • Python: 3.7.14
  • torch: 1.10.0+cu111
  • torchaudio: 0.10.0+cu111

Could anyone help me understand why this keeps running, or suggest good ways to debug the backward() call?
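
The only general idea I’ve had so far is to periodically dump every thread’s Python stack trace so I can at least see where it is stuck (just an idea, not verified that it helps here):

import faulthandler

# Dump the stack of every thread to stderr every 60 seconds,
# so a traceback gets printed even while loss.backward() is hanging.
faulthandler.dump_traceback_later(60, repeat=True)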

Thank you so much!!

Could you update PyTorch and all sublibraries to the latest version and check if the code is still hanging, please?

Thank you for your advice! I’ll try it soon.

After updating everything, the issue is gone. Thank you very much!

New versions:

  • Python 3.8.13
  • pytorch-lightning 2.1.1
  • torch 2.0.1
  • torchaudio 2.0.2

Cool! NIT: the latest PyTorch version is 2.1.0 in case you want to update.