Using torch.stft / torch.istft on PackedSequence data

Will it be possible down the line to have torch.stft or torch.istft support PackedSequence objects?

I have provided a minimal working example here:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


fft_size = 1024
hop_size = 256


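# Three clips of equal length, stacked into a (batch, time) tensor
batched_audio = torch.rand((3, 50000))
window = torch.hann_window(fft_size)
# return_complex=True is required for real inputs in recent PyTorch releases
Y = torch.stft(batched_audio, n_fft=fft_size, hop_length=hop_size,
               window=window, return_complex=True)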

print(torch.is_tensor(Y))
# Out: True

In the above cell, all the audio clips have equal length (50000 samples). Below is an example with varying-length audio (50000, 40000, and 30000 samples respectively).

batched_audio = [torch.rand(50000), torch.rand(40000), torch.rand(30000)]
lengths = [len(x) for x in batched_audio]
padded_audio = pad_sequence(batched_audio, batch_first=True)
packed_audio = pack_padded_sequence(padded_audio, lengths, batch_first=True)

Y = torch.stft(packed_audio, n_fft=fft_size, hop_length=hop_size,
               window=window, return_complex=True)

# Out: AttributeError: 'PackedSequence' object has no attribute 'dim'
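
For context, the error seems to come from the fact that a PackedSequence is a named tuple (holding data and batch_sizes) rather than a tensor, so tensor methods such as dim() are missing. A minimal sketch of one stopgap follows, assuming it is acceptable to unpack back to a padded tensor before the transform. The clip lengths are shortened here only so the sketch runs quickly, and the names unpacked_audio and unpacked_lengths are just illustrative.

```python
import torch
from torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                pad_packed_sequence)

fft_size = 1024
hop_size = 256
window = torch.hann_window(fft_size)

# Shorter clips than in the post, only to keep the sketch quick to run.
clips = [torch.rand(5000), torch.rand(4000), torch.rand(3000)]
lengths = [len(x) for x in clips]
packed_audio = pack_padded_sequence(pad_sequence(clips, batch_first=True),
                                    lengths, batch_first=True)

# A PackedSequence carries .data and .batch_sizes but has no .dim() method.
has_dim = hasattr(packed_audio, "dim")

# Unpack to a zero-padded (batch, time) tensor, which torch.stft accepts.
unpacked_audio, unpacked_lengths = pad_packed_sequence(packed_audio,
                                                       batch_first=True)
Y = torch.stft(unpacked_audio, n_fft=fft_size, hop_length=hop_size,
               window=window, return_complex=True)
```

This sidesteps the error but does not give true packed-sequence support: the STFT is still computed over the zero padding.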

This would be really handy for batch processing variable-length audio; the STFT and ISTFT are important operations.

I guess the workaround for this right now would be to pre-compute STFTs of all the waveforms before calling pad_sequence, but I’m trying to reduce the GPU memory footprint of my tests. I thought on-the-fly STFT computation would be more efficient.
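
In case it helps anyone with the same problem, here is a rough sketch of an on-the-fly alternative: take the STFT of the zero-padded batch and keep track of how many frames are valid for each clip. The helper name padded_stft and its return values are hypothetical, not part of PyTorch. It assumes the default center=True, under which a clip of L samples yields 1 + L // hop_length frames.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def padded_stft(clips, n_fft=1024, hop_length=256):
    """Hypothetical helper: STFT of a zero-padded batch plus a frame mask."""
    padded = pad_sequence(clips, batch_first=True)  # (batch, max_len)
    window = torch.hann_window(n_fft, device=padded.device)
    spec = torch.stft(padded, n_fft=n_fft, hop_length=hop_length,
                      window=window, return_complex=True)
    # Assumes center=True (the default): 1 + L // hop_length frames per clip.
    frame_lengths = torch.tensor([1 + len(x) // hop_length for x in clips],
                                 device=padded.device)
    frame_idx = torch.arange(spec.shape[-1], device=padded.device)
    # mask[b, t] is True where frame t belongs to clip b, False for padding.
    mask = frame_idx.unsqueeze(0) < frame_lengths.unsqueeze(1)
    return spec, frame_lengths, mask


# Shortened clips so the example runs quickly.
clips = [torch.rand(5000), torch.rand(4000), torch.rand(3000)]
spec, frame_lengths, mask = padded_stft(clips)
```

Two caveats: the last few valid frames of the shorter clips overlap zero padding instead of reflect padding, so they will differ slightly from a per-clip STFT, and the padded spectrogram still occupies memory for the padded frames.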