Hi everyone. I run a model on the STFT of a signal and then use istft to convert the result back to a waveform. The spectrogram I pass to istft has the same shape as the one produced by stft, but the waveform I get back does not have the same length as the original signal. I read the discussion https://github.com/pytorch/pytorch/issues/62747 on GitHub, and it seems the length problem was solved there, but I still run into it. Here is my code (forgive me if the layout is a bit messy):
```python
import os

import torch
import torchaudio

# MMDenseNet, musdb18_root and names are defined elsewhere in my script
# (MMDenseNet comes from the DNN-based_source_separation repo).

def main():
    sample_rate = 44100
    device = torch.device('cuda:0')
    model_path = '/data2/Will/DNN-based_source_separation-main/DNN-based_source_separation-main/src/TEST/best.pth'

    model = MMDenseNet.build_model(model_path)  # the model
    config = torch.load(model_path, map_location=lambda storage, loc: storage)  # the checkpoint (dict)
    model.load_state_dict(config['state_dict'])
    model = model.to(device)
    model.eval()

    channel = model.in_channels
    n_fft = 2048
    hop = 1024
    window_fn = torch.hann_window(n_fft, periodic=True, device=device)

    # set ABS_path whether or not 'TT' already exists
    ABS_path = os.path.join(os.getcwd(), 'TT')
    if os.path.isdir('TT'):
        print('exist')
    else:
        os.mkdir('TT')

    for name in names:
        mixture_path = os.path.join(musdb18_root, 'test', name, "mixture.wav")
        with torch.no_grad():
            source, sr = torchaudio.load(mixture_path)
            source = source.to(device)
            source_duration = source.size(1) / 44100

            source_stft = torch.stft(source, n_fft=n_fft, hop_length=hop, window=window_fn, return_complex=True)
            source_stft = torch.unsqueeze(source_stft, dim=0)  # add a batch dim -> (B, C, F, T)
            print(source_stft.shape)

            source_stft_amp = torch.abs(source_stft)
            source_stft_angle = torch.angle(source_stft)
            estimated_amp = model(source_stft_amp)
            print(estimated_amp.shape)

            # reattach the mixture phase to the estimated magnitude
            estimated = estimated_amp * torch.exp(1j * source_stft_angle)

            channels = estimated.size()[:-2]  # keep the B, C dims
            estimated = estimated.view(-1, *estimated.size()[-2:])
            print(estimated.shape)

            estimated_out = torch.istft(estimated, n_fft=n_fft, hop_length=hop, window=window_fn, return_complex=False)
            estimated_out = estimated_out.view(*channels, -1).squeeze(0).cpu()
            print('TEST', estimated_out.shape, source.shape)

            est_path = os.path.join(ABS_path, '{}.wav'.format(name))
            torchaudio.save(est_path, estimated_out, sample_rate=sample_rate, channels_first=True, bits_per_sample=16)
```
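In case it is relevant: from what I can tell, even a plain stft → istft round trip with no model in between comes back shorter whenever the input length is not a multiple of hop_length, unless I pass length= to istft. Here is a minimal sketch of what I mean (the sizes are just an arbitrary example, not my actual audio, and the printed shapes are what I would expect with the default center=True):

```python
import torch

n_fft, hop = 2048, 1024
window = torch.hann_window(n_fft, periodic=True)

# input whose length (132800) is NOT a multiple of hop
x = torch.randn(2, 44100 * 3 + 500)

spec = torch.stft(x, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)

# without length=, the output has (n_frames - 1) * hop samples,
# i.e. the original length rounded down to a multiple of hop
y = torch.istft(spec, n_fft=n_fft, hop_length=hop, window=window)
print(x.shape, y.shape)  # expected: [2, 132800] vs [2, 132096] -- shorter

# passing the original length pads/trims the result back to the exact size
y_fixed = torch.istft(spec, n_fft=n_fft, hop_length=hop, window=window, length=x.size(-1))
print(y_fixed.shape)  # expected: [2, 132800]
```

Based on that, it looks like passing length=source.size(1) to my torch.istft call would make the sizes line up, but I am not sure whether that is the intended fix from the linked issue or whether I am misusing stft/istft somewhere else.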
Could this be because of my PyTorch version? I am using PyTorch 1.12.1.