GPU computations weirdly slowed down when calling stft

Hi,
I was training some audiovisual source separation models, and I was computing the STFT on the GPU to speed up preprocessing. The following code is a wrapper around an nn.Module that applies this preprocessing.

The relevant code is:

    def forward(self, inputs: dict):
        """
        inputs contains the following keys:
           llcp_embedding: visual (LLCP) embeddings, e.g. of shape (N, 100, 512)
           audio: waveform of the main source, shape (N, L)
           audio_acmt: waveform of the accompanying (secondary) source, shape (N, L)
        """
        self.n_sources = 2
        with torch.no_grad():
            llcp_embeddings = inputs['llcp_embedding'].transpose(1, 2)

            srcm = inputs['audio']
            srcs = inputs['audio_acmt']

            # Computing STFT
            spm = self.wav2sp(srcm)  # Spectrogram main BxFxTx2
            sps = self.wav2sp(srcs)  # Spectrogram secondary BxFxTx2

            # Mixture spectrogram as the average of the sources
            sources = [spm, sps]
            sp_mix_raw = sum(sources) / self.n_sources
            # Downsampling along frequency to save memory
            spm = spm[:, ::2, ...]
            sps = sps[:, ::2, ...]
            sp_mix = sp_mix_raw[:, ::2, ...]  # BxFxTx2

            x = sp_mix.permute(0, 3, 1, 2)  # Bx2xFxT

        pred = self.core_forward(audio_input=x.detach().requires_grad_(),
                                 visual_input=llcp_embeddings)
        return pred


    def core_forward(self, *, audio_input, visual_input):
        outx = self.llcp(audio_input, visual_input)
        output = {'mask': outx, 'ind_end_feats': None, 'visual_features': None}
        return output

Here, self.wav2sp uses the Spectrogram transform from torchaudio, and self.core_forward just runs the core network (self.llcp).
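For anyone reproducing this without torchaudio, here is a minimal sketch of what `self.wav2sp` roughly does, built on `torch.stft` instead. It assumes `Spectrogram(n_fft=1022, hop_length=256, power=None)`-like behaviour (Hann window, `center=True`) and returns real/imaginary pairs via `torch.view_as_real`; the real wrapper may differ. With the sizes from the profiling script it reproduces the (N, 2, 256, 256) network input, and it shows that `x` ends up as a strided, non-contiguous view:

```python
import torch

N_FFT = 1022
HOP = 256

def wav2sp(wav: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for self.wav2sp: (B, L) waveform -> (B, F, T, 2)."""
    window = torch.hann_window(N_FFT, device=wav.device)
    spec = torch.stft(wav, n_fft=N_FFT, hop_length=HOP, window=window,
                      center=True, return_complex=True)  # (B, F, T), complex
    return torch.view_as_real(spec)  # (B, F, T, 2): real and imaginary parts

B, L = 2, 65535
spm = wav2sp(torch.rand(B, L))
sps = wav2sp(torch.rand(B, L))

sp_mix = ((spm + sps) / 2)[:, ::2, ...]  # average the sources, keep every 2nd bin
x = sp_mix.permute(0, 3, 1, 2)           # (B, 2, F/2, T), still a view

print(x.shape)  # torch.Size([2, 2, 256, 256])
print(x.stride(), x.is_contiguous())
```

The slice and the permute only change strides, so no data is copied until something downstream needs a different layout.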

I realised that if I clone the spectrogram instead of detaching it, the forward pass becomes much faster:

        pred = self.core_forward(audio_input=x.clone().requires_grad_(),
                                 visual_input=llcp_embeddings)
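One difference between the two calls that may be worth checking (a guess, not a confirmed diagnosis): `detach()` returns a view that shares storage and strides with `x`, while `clone()` allocates new memory. According to the `torch.preserve_format` documentation, `clone()` only copies the input strides when the input is dense and non-overlapping; `x` is not, because the `[:, ::2, ...]` slice leaves gaps, so the clone gets a dense layout. This sketch uses random data shaped like `sp_mix_raw` to compare the layouts:

```python
import torch

# Same layout as x in forward(): (B, F, T, 2), downsampled along F, then permuted.
sp_mix_raw = torch.rand(2, 512, 256, 2)
x = sp_mix_raw[:, ::2, ...].permute(0, 3, 1, 2)  # (B, 2, F/2, T), non-dense view

detached = x.detach()  # shares storage and strides with x
cloned = x.clone()     # new allocation; dense because x is not

for name, t in [("x", x), ("detach", detached), ("clone", cloned),
                ("contiguous", x.contiguous())]:
    print(f"{name:>10}: stride={t.stride()} "
          f"contiguous={t.is_contiguous()} "
          f"channels_last={t.is_contiguous(memory_format=torch.channels_last)}")
```

If memory layout is what matters, passing `x.contiguous().requires_grad_()` to `core_forward` should be roughly as fast as the `clone()` version; that is an untested hypothesis.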

Profiling results without cloning:

    Self CPU time total: 8.264s
    CUDA time total: 8.264s

Profiling results with cloning:

    Self CPU time total: 107.938ms
    CUDA time total: 107.809ms

That is roughly a 76x difference for the same computation.

Does anyone know why? I have tried several models and it always happens, so it does not seem to be specific to this one.

I profiled it like this:

    import torch
    import torch.autograd.profiler as profiler

    # LlcpNet and Llcp are the model classes from my project.

    if __name__ == '__main__':
        USE_W = True
        N = 2
        DEVICE = torch.device('cuda:0')
        # DEVICE = torch.device('cpu')

        if USE_W:
            model = LlcpNet(audio_length=65535, audio_samplerate=16384,
                            n_fft=1022, hop_length=256, n_mel=128, sp_freq_shape=1022 // 2 + 1,
                            video_enabled=False, llcp_enabled=True,
                            skeleton_enabled=False, device=DEVICE).to(DEVICE)
            inputs = {'llcp_embedding': torch.rand(N, 100, 512).to(DEVICE),
                      'audio': torch.rand(N, 65535).to(DEVICE),
                      'audio_acmt': torch.rand(N, 65535).to(DEVICE)}
        else:
            model = Llcp().to(DEVICE)
            inputs = {'input_video': torch.rand(N, 512, 100).to(DEVICE),
                      'input_audio': torch.rand(N, 2, 256, 256).to(DEVICE)}
        # input_audio will be (N,2,256,256)
        # input_video will be of size (N,512,100)

        with profiler.profile(with_stack=True, profile_memory=True, use_cuda=True) as prof:
            output = model(inputs) if USE_W else model(**inputs)
        print(prof.key_averages().table(sort_by='self_cpu_time_total', row_limit=-1))
        print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=-1))
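
Profiling a single forward pass also captures one-off costs (CUDA context setup, cuDNN algorithm selection, caching-allocator growth), and CUDA kernels run asynchronously. Below is a sketch of a timing loop with warm-up and `torch.cuda.synchronize()` that compares the detach, clone and contiguous variants. The `nn.Conv2d` model is a hypothetical stand-in for `LlcpNet`, and the tensor sizes mirror the script above:

```python
import time
import torch
import torch.nn as nn

def time_forward(model: nn.Module, make_input, n_warmup: int = 3, n_iters: int = 10) -> float:
    """Average seconds per forward pass, measured after warm-up iterations.

    Warm-up keeps one-off initialization out of the measurement, and
    synchronize() waits for queued GPU kernels before the clock is read.
    """
    device = next(model.parameters()).device
    for _ in range(n_warmup):
        model(make_input())
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(n_iters):
        model(make_input())
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return (time.perf_counter() - start) / n_iters

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Conv2d(2, 16, kernel_size=3, padding=1).to(device)  # hypothetical stand-in

sp_mix_raw = torch.rand(2, 512, 256, 2, device=device)
x = sp_mix_raw[:, ::2, ...].permute(0, 3, 1, 2)  # same layout as in forward()

variants = {
    "detach": lambda: x.detach().requires_grad_(),
    "clone": lambda: x.clone().requires_grad_(),
    "contiguous": lambda: x.contiguous().requires_grad_(),
}
results = {name: time_forward(model, fn) for name, fn in variants.items()}
for name, secs in results.items():
    print(f"{name:>10}: {secs * 1e3:.3f} ms / forward")
```

`torch.utils.benchmark.Timer` also takes care of warm-up and CUDA synchronization and may be a more convenient alternative to a hand-written loop.
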

A standalone script to check this can be found at:

This occurs on at least an RTX 3090, a Quadro P6000 and a Titan V, as well as some other GPUs.