How to use torchaudio's filtfilt() function?

I’m trying to move from scipy to torchaudio. Here is my code:

import torch
from torchaudio.functional.filtering import filtfilt
from scipy import signal

bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)

audio = sample_input

print(f"Audio contains nan: {torch.isnan(torch.from_numpy(audio).float().to(torch.float64)).any()}")
print(f"Audio contains inf: {torch.isinf(torch.from_numpy(audio).float().to(torch.float64)).any()}")
print(f"Audio min: {torch.from_numpy(audio).float().to(torch.float64).min()}")
print(f"Audio max: {torch.from_numpy(audio).float().to(torch.float64).max()}")
print(f"Audio mean: {torch.from_numpy(audio).float().to(torch.float64).mean()}")
print(f"Audio shape: {torch.from_numpy(audio).float().to(torch.float64).shape}")

print(f"bh contains nan: {torch.isnan(torch.from_numpy(bh).float().to(torch.float64)).any()}")
print(f"bh contains inf: {torch.isinf(torch.from_numpy(bh).float().to(torch.float64)).any()}")
print(f"bh min: {torch.from_numpy(bh).float().to(torch.float64).min()}")
print(f"bh max: {torch.from_numpy(bh).float().to(torch.float64).max()}")
print(f"bh mean: {torch.from_numpy(bh).float().to(torch.float64).mean()}")
print(f"bh shape: {torch.from_numpy(bh).float().to(torch.float64).shape}")

print(f"ah contains nan: {torch.isnan(torch.from_numpy(ah).float().to(torch.float64)).any()}")
print(f"ah contains inf: {torch.isinf(torch.from_numpy(ah).float().to(torch.float64)).any()}")
print(f"ah min: {torch.from_numpy(ah).float().to(torch.float64).min()}")
print(f"ah max: {torch.from_numpy(ah).float().to(torch.float64).max()}")
print(f"ah mean: {torch.from_numpy(ah).float().to(torch.float64).mean()}")
print(f"ah shape: {torch.from_numpy(ah).float().to(torch.float64).shape}")


audio = filtfilt(
    waveform=torch.from_numpy(audio).float().to(torch.float64),
    a_coeffs=torch.from_numpy(ah).float().to(torch.float64),
    b_coeffs=torch.from_numpy(bh).float().to(torch.float64)
)

print(f"Audio after filtfilt : {audio}")

But the actual output is this:

Audio contains nan: False
Audio contains inf: False
Audio min: -0.858154296875
Audio max: 0.8670654296875
Audio mean: 0.00011500650977929034
Audio shape: torch.Size([1149120])
bh contains nan: False
bh contains inf: False
bh min: -9.699606895446777
bh max: 9.699606895446777
bh mean: 0.0
bh shape: torch.Size([6])
ah contains nan: False
ah contains inf: False
ah min: -9.639544486999512
ah max: 9.757863998413086
ah mean: 1.3907750447591147e-07
ah shape: torch.Size([6])
Audio after filtfilt : tensor([nan, nan, nan,  ..., nan, nan, nan], dtype=torch.float64)

Am I using this function the wrong way? lol 😂

By the way, the same input works fine in scipy:

from scipy.signal import filtfilt
audio = sample_input
audio = filtfilt(bh, ah, audio)
print(f"Audio after filtfilt : {audio}")

Output:

Audio after filtfilt : [-2.48281273e-07 -2.66557553e-07 -2.84952686e-07 ...  1.79293785e-17
  4.10740501e-17  6.24618812e-17]
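
One plausible explanation for the difference: scipy consumes bh/ah as the float64 arrays that butter() returns, while the torch path squeezes them through float32 first (the .float() call) before casting back to float64. A 5th-order high-pass with a 48 Hz cutoff at 16 kHz has poles very close to the unit circle, so that float32 round-trip can be enough to destabilize the filter. A quick sketch (plain numpy/scipy, no torch needed) to check the pole magnitudes:

import numpy as np
from scipy import signal

bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)

# The filter's poles are the roots of the denominator polynomial; a stable IIR
# filter needs every pole strictly inside the unit circle (|pole| < 1).
print("max |pole| with float64 coeffs:", np.abs(np.roots(ah)).max())

# Emulate the float32 round-trip that .float().to(torch.float64) performs.
ah_roundtrip = ah.astype(np.float32).astype(np.float64)
print("max |pole| after float32 round-trip:", np.abs(np.roots(ah_roundtrip)).max())

If the second number reaches 1.0 or above, the filter is unstable and the forward-backward pass blows up to inf/NaN, which would match the output above.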

Closing the issue: the multiple dtype conversions (float64 → float32 → float64) on the filter coefficients cause precision loss.
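
For anyone hitting the same thing, here is a minimal sketch of the workaround, assuming the goal is simply to keep everything in float64 end to end (sample_input is the same placeholder as above). butter() already returns float64 arrays and torch.from_numpy preserves the numpy dtype, so the intermediate .float() cast can simply be dropped:

import torch
import torchaudio.functional as F
from scipy import signal

bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)

audio = sample_input  # same placeholder numpy array as in the report

# Keep the coefficients in float64 the whole way; never pass them through float32.
waveform = torch.from_numpy(audio).to(torch.float64)
a_coeffs = torch.from_numpy(ah)  # already float64
b_coeffs = torch.from_numpy(bh)  # already float64

filtered = F.filtfilt(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
print(f"Audio after filtfilt : {filtered}")

Note that the edge samples may still differ slightly from scipy's result, since scipy.signal.filtfilt pads the signal by default (padtype='odd') before filtering.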