I am using the pghi algorithm for phase reconstruction. As my framework of choice is PyTorch, I have tried to re-implement the algorithm from pypghi found here.
For small matrices the algorithms converge to similar solutions, however for large matrices the PyTorch algorithm never halts and has a memory leak problem. I have tried disabling gradients no_grad
but that has no effect.
Orginal pypghi algorithm
import heapq
import numpy as np
def ppghi(X, win_length=2048, hop_length=512, gamma=None, tol=1e-6):
if gamma is None:
gamma = 2 * np.pi * ((-win_length**2 / (8 * np.log(0.01)))**.5)**2
spectrogram = X.copy()
abstol = np.array([1e-10], dtype=spectrogram.dtype)[
0] # if abstol is not the same type as spectrogram then casting occurs
phase = np.zeros_like(spectrogram)
max_val = np.amax(spectrogram) # Find maximum value to start integration
max_x, max_y = np.where(spectrogram == max_val)
max_pos = max_x[0], max_y[0]
if max_val <= abstol: # Avoid integrating the phase for the spectrogram of a silent signal
return phase
M2 = spectrogram.shape[0]
N = spectrogram.shape[1]
fmul = gamma / (hop_length * win_length)
Y = np.empty((spectrogram.shape[0] + 2, spectrogram.shape[1] + 2),
dtype=spectrogram.dtype)
Y[1:-1, 1:-1] = np.log(spectrogram + 1e-50)
Y[0, :] = Y[1, :]
Y[-1, :] = Y[-2, :]
Y[:, 0] = Y[:, 1]
Y[:, -1] = Y[:, -2]
dxdw = (Y[1:-1, 2:] - Y[1:-1, :-2]) / 2
dxdt = (Y[2:, 1:-1] - Y[:-2, 1:-1]) / 2
fgradw = dxdw / fmul + (2 * np.pi * hop_length /
win_length) * np.arange(int(win_length / 2) + 1)
tgradw = -fmul * dxdt + np.pi
magnitude_heap = [(-max_val, max_pos)
] # Numba requires heap to be initialized with content
spectrogram[max_pos] = abstol
small_x, small_y = np.where(spectrogram < max_val * tol)
for x, y in zip(small_x, small_y):
spectrogram[x, y] = abstol # Do not integrate over silence
while max_val > abstol:
while len(
magnitude_heap
) > 0: # Integrate around maximum value until reaching silence
max_val, max_pos = heapq.heappop(magnitude_heap)
col = max_pos[0]
row = max_pos[1]
#Spread to 4 direct neighbors
N_pos = col + 1, row
S_pos = col - 1, row
E_pos = col, row + 1
W_pos = col, row - 1
if max_pos[0] < M2 - 1 and spectrogram[N_pos] > abstol:
phase[N_pos] = phase[max_pos] + (fgradw[max_pos] +
fgradw[N_pos]) / 2
heapq.heappush(magnitude_heap, (-spectrogram[N_pos], N_pos))
spectrogram[N_pos] = abstol
if max_pos[0] > 0 and spectrogram[S_pos] > abstol:
phase[S_pos] = phase[max_pos] - (fgradw[max_pos] +
fgradw[S_pos]) / 2
heapq.heappush(magnitude_heap, (-spectrogram[S_pos], S_pos))
spectrogram[S_pos] = abstol
if max_pos[1] < N - 1 and spectrogram[E_pos] > abstol:
phase[E_pos] = phase[max_pos] + (tgradw[max_pos] +
tgradw[E_pos]) / 2
heapq.heappush(magnitude_heap, (-spectrogram[E_pos], E_pos))
spectrogram[E_pos] = abstol
if max_pos[1] > 0 and spectrogram[W_pos] > abstol:
phase[W_pos] = phase[max_pos] - (tgradw[max_pos] +
tgradw[W_pos]) / 2
heapq.heappush(magnitude_heap, (-spectrogram[W_pos], W_pos))
spectrogram[W_pos] = abstol
max_val = np.amax(
spectrogram) # Find new maximum value to start integration
max_x, max_y = np.where(spectrogram == max_val)
max_pos = max_x[0], max_y[0]
heapq.heappush(magnitude_heap, (-max_val, max_pos))
spectrogram[max_pos] = abstol
return phase
PyTorch Implementation
import heapq
from typing import Optional
import numpy as np
import torch
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
def centered_finite_differences(x: Tensor,
n: int = 1,
dim: int = -1) -> Tensor:
# Make `dim` to be the last dimension
x = x.transpose(-1, dim)
for _ in range(n):
# Pad the tensor to ensure same number of elements as input after the differences
x = F.pad(x, (1, 1), mode="replicate")
# Compute the centered finite differences: f'(x) = (f(x_{i-1}) - (f_x{i+1}))/2h; h=1
x = (x[..., 2:] - x[..., :-2]) / 2
return x.transpose(-1, dim)
def pghi(specgram: Tensor,
win_length: int = 2048,
hop_length: int = 1024,
gamma: Optional[float] = None,
tol: float = 1e-6) -> Tensor:
# Set default values for `None` arguments
gamma = gamma or (2 * np.pi * np.square(
np.sqrt(-np.square(win_length) / (8 * np.log(0.01)))))
# Detach computation graph and keep a copy of the specgram
specgram = specgram.detach().clone()
# Initalize values for the phase integration algorithm
abstol = torch.tensor(1e-10, device=specgram.device, dtype=specgram.dtype)
phase = torch.zeros_like(specgram)
max_val = torch.max(specgram)
max_pos = torch.where(specgram == max_val)
# Avoid integrating the phase for a silent signal
if max_val <= abstol:
return phase
# Compute time and frequency derivatives of the log magnitude specgram
dtime = centered_finite_differences(torch.log(specgram + 1e-50), dim=-1)
dfreq = centered_finite_differences(torch.log(specgram + 1e-50), dim=-2)
# Scale the time and frequency derivatives
scaling = gamma / (hop_length * win_length)
gradtime = -scaling * dtime + torch.pi
gradfreq = dfreq / scaling + (
2 * torch.pi * hop_length / win_length) * torch.arange(
win_length // 2 + 1, device=specgram.device).reshape(-1, 1)
# Avoid integrating over silence
specgram[specgram < max_val * tol] = abstol
# Initialize the heap for integration with the max values
magnitude_heap = [(-max_val, max_pos)]
freq_channels, time_frames = specgram.shape
while max_val > abstol:
# Integrate around the maximum value until reaching silence
while len(magnitude_heap) > 0:
# Remove the top of the heap
max_val, max_pos = heapq.heappop(magnitude_heap)
# Construct positions for the neighbors
col, row = max_pos
N_pos = col + 1, row
S_pos = col - 1, row
E_pos = col, row + 1
W_pos = col, row - 1
# Integrate around the north neighbor
if max_pos[0] < freq_channels - 1 and specgram[N_pos] > abstol:
phase[N_pos] = phase[max_pos] + (gradtime[max_pos] +
gradtime[N_pos]) / 2
heapq.heappush(magnitude_heap, (-specgram[N_pos], N_pos))
specgram[N_pos] = abstol
# Integrate around the south neighbor
if max_pos[0] > 0 and specgram[S_pos] > abstol:
phase[S_pos] = phase[max_pos] - (gradtime[max_pos] +
gradtime[S_pos]) / 2
heapq.heappush(magnitude_heap, (-specgram[S_pos], S_pos))
specgram[S_pos] = abstol
# Integrate around the east neighbor
if max_pos[1] < time_frames - 1 and specgram[E_pos] > abstol:
phase[E_pos] = phase[max_pos] + (gradfreq[max_pos] +
gradfreq[E_pos]) / 2
heapq.heappush(magnitude_heap, (-specgram[E_pos], E_pos))
specgram[E_pos] = abstol
# Integrate around the west neighbor
if max_pos[1] > 0 and specgram[W_pos] > abstol:
phase[W_pos] = phase[max_pos] - (gradfreq[max_pos] +
gradfreq[W_pos]) / 2
heapq.heappush(magnitude_heap, (-specgram[W_pos], W_pos))
specgram[W_pos] = abstol
# Find the new maximum value to start the integration
max_val = torch.max(specgram)
max_pos = torch.where(specgram == max_val)
heapq.heappush(magnitude_heap, (-max_val, max_pos))
specgram[max_pos] = abstol
return phase
Check for a small matrix
x = torch.randn((3, 6)).abs()
print("Input:\n", x)
phase = pghi(x, win_length=4, hop_length=1)
print("PyTorch PGHI:\n", phase)
pphase = ppghi(x.T.numpy(), win_length=4, hop_length=1)
pphase = torch.from_numpy(pphase).T
print("PyPGHI:\n", pphase)
print("Exact solution:", torch.allclose(phase, pphase)) # True