A port of the numpy pghi algorithm fails to halt for large tensors but works fine for small tensors

I am using the pghi algorithm for phase reconstruction. As my framework of choice is PyTorch, I have tried to re-implement the algorithm from pypghi found here.

For small matrices the algorithms converge to similar solutions, however for large matrices the PyTorch algorithm never halts and has a memory leak problem. I have tried disabling gradients no_grad but that has no effect.

Orginal pypghi algorithm

import heapq 
import numpy as np 

def ppghi(X, win_length=2048, hop_length=512, gamma=None, tol=1e-6):
    if gamma is None:
        gamma = 2 * np.pi * ((-win_length**2 / (8 * np.log(0.01)))**.5)**2

    spectrogram = X.copy()
    abstol = np.array([1e-10], dtype=spectrogram.dtype)[
        0]  # if abstol is not the same type as spectrogram then casting occurs
    phase = np.zeros_like(spectrogram)
    max_val = np.amax(spectrogram)  # Find maximum value to start integration
    max_x, max_y = np.where(spectrogram == max_val)
    max_pos = max_x[0], max_y[0]

    if max_val <= abstol:  # Avoid integrating the phase for the spectrogram of a silent signal
        return phase

    M2 = spectrogram.shape[0]
    N = spectrogram.shape[1]

    fmul = gamma / (hop_length * win_length)

    Y = np.empty((spectrogram.shape[0] + 2, spectrogram.shape[1] + 2),
                 dtype=spectrogram.dtype)
    Y[1:-1, 1:-1] = np.log(spectrogram + 1e-50)
    Y[0, :] = Y[1, :]
    Y[-1, :] = Y[-2, :]
    Y[:, 0] = Y[:, 1]
    Y[:, -1] = Y[:, -2]
    dxdw = (Y[1:-1, 2:] - Y[1:-1, :-2]) / 2
    dxdt = (Y[2:, 1:-1] - Y[:-2, 1:-1]) / 2

    fgradw = dxdw / fmul + (2 * np.pi * hop_length /
                            win_length) * np.arange(int(win_length / 2) + 1)
    tgradw = -fmul * dxdt + np.pi

    magnitude_heap = [(-max_val, max_pos)
                      ]  # Numba requires heap to be initialized with content
    spectrogram[max_pos] = abstol

    small_x, small_y = np.where(spectrogram < max_val * tol)
    for x, y in zip(small_x, small_y):
        spectrogram[x, y] = abstol  # Do not integrate over silence

    while max_val > abstol:
        while len(
                magnitude_heap
        ) > 0:  # Integrate around maximum value until reaching silence
            max_val, max_pos = heapq.heappop(magnitude_heap)

            col = max_pos[0]
            row = max_pos[1]

            #Spread to 4 direct neighbors
            N_pos = col + 1, row
            S_pos = col - 1, row
            E_pos = col, row + 1
            W_pos = col, row - 1

            if max_pos[0] < M2 - 1 and spectrogram[N_pos] > abstol:
                phase[N_pos] = phase[max_pos] + (fgradw[max_pos] +
                                                 fgradw[N_pos]) / 2
                heapq.heappush(magnitude_heap, (-spectrogram[N_pos], N_pos))
                spectrogram[N_pos] = abstol

            if max_pos[0] > 0 and spectrogram[S_pos] > abstol:
                phase[S_pos] = phase[max_pos] - (fgradw[max_pos] +
                                                 fgradw[S_pos]) / 2
                heapq.heappush(magnitude_heap, (-spectrogram[S_pos], S_pos))
                spectrogram[S_pos] = abstol

            if max_pos[1] < N - 1 and spectrogram[E_pos] > abstol:
                phase[E_pos] = phase[max_pos] + (tgradw[max_pos] +
                                                 tgradw[E_pos]) / 2
                heapq.heappush(magnitude_heap, (-spectrogram[E_pos], E_pos))
                spectrogram[E_pos] = abstol

            if max_pos[1] > 0 and spectrogram[W_pos] > abstol:
                phase[W_pos] = phase[max_pos] - (tgradw[max_pos] +
                                                 tgradw[W_pos]) / 2
                heapq.heappush(magnitude_heap, (-spectrogram[W_pos], W_pos))
                spectrogram[W_pos] = abstol

        max_val = np.amax(
            spectrogram)  # Find new maximum value to start integration
        max_x, max_y = np.where(spectrogram == max_val)
        max_pos = max_x[0], max_y[0]
        heapq.heappush(magnitude_heap, (-max_val, max_pos))
        spectrogram[max_pos] = abstol

    return phase

PyTorch Implementation

import heapq
from typing import Optional

import numpy as np
import torch
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F


def centered_finite_differences(x: Tensor,
                                n: int = 1,
                                dim: int = -1) -> Tensor:
    # Make `dim` to be the last dimension
    x = x.transpose(-1, dim)

    for _ in range(n):
        # Pad the tensor to ensure same number of elements as input after the differences
        x = F.pad(x, (1, 1), mode="replicate")

        # Compute the centered finite differences: f'(x) = (f(x_{i-1}) - (f_x{i+1}))/2h; h=1
        x = (x[..., 2:] - x[..., :-2]) / 2

    return x.transpose(-1, dim)


def pghi(specgram: Tensor,
         win_length: int = 2048,
         hop_length: int = 1024,
         gamma: Optional[float] = None,
         tol: float = 1e-6) -> Tensor:
    # Set default values for `None` arguments
    gamma = gamma or (2 * np.pi * np.square(
        np.sqrt(-np.square(win_length) / (8 * np.log(0.01)))))

    # Detach computation graph and keep a copy of the specgram
    specgram = specgram.detach().clone()

    # Initalize values for the phase integration algorithm
    abstol = torch.tensor(1e-10, device=specgram.device, dtype=specgram.dtype)
    phase = torch.zeros_like(specgram)
    max_val = torch.max(specgram)
    max_pos = torch.where(specgram == max_val)

    # Avoid integrating the phase for a silent signal
    if max_val <= abstol:
        return phase

    # Compute time and frequency derivatives of the log magnitude specgram
    dtime = centered_finite_differences(torch.log(specgram + 1e-50), dim=-1)
    dfreq = centered_finite_differences(torch.log(specgram + 1e-50), dim=-2)

    # Scale the time and frequency derivatives
    scaling = gamma / (hop_length * win_length)
    gradtime = -scaling * dtime + torch.pi
    gradfreq = dfreq / scaling + (
        2 * torch.pi * hop_length / win_length) * torch.arange(
            win_length // 2 + 1, device=specgram.device).reshape(-1, 1)

    # Avoid integrating over silence
    specgram[specgram < max_val * tol] = abstol

    # Initialize the heap for integration with the max values
    magnitude_heap = [(-max_val, max_pos)]
    freq_channels, time_frames = specgram.shape

    while max_val > abstol:
        # Integrate around the maximum value until reaching silence
        while len(magnitude_heap) > 0:
            # Remove the top of the heap
            max_val, max_pos = heapq.heappop(magnitude_heap)

            # Construct positions for the neighbors
            col, row = max_pos
            N_pos = col + 1, row
            S_pos = col - 1, row
            E_pos = col, row + 1
            W_pos = col, row - 1

            # Integrate around the north neighbor
            if max_pos[0] < freq_channels - 1 and specgram[N_pos] > abstol:
                phase[N_pos] = phase[max_pos] + (gradtime[max_pos] +
                                                 gradtime[N_pos]) / 2
                heapq.heappush(magnitude_heap, (-specgram[N_pos], N_pos))
                specgram[N_pos] = abstol

            # Integrate around the south neighbor
            if max_pos[0] > 0 and specgram[S_pos] > abstol:
                phase[S_pos] = phase[max_pos] - (gradtime[max_pos] +
                                                 gradtime[S_pos]) / 2
                heapq.heappush(magnitude_heap, (-specgram[S_pos], S_pos))
                specgram[S_pos] = abstol

            # Integrate around the east neighbor
            if max_pos[1] < time_frames - 1 and specgram[E_pos] > abstol:
                phase[E_pos] = phase[max_pos] + (gradfreq[max_pos] +
                                                 gradfreq[E_pos]) / 2
                heapq.heappush(magnitude_heap, (-specgram[E_pos], E_pos))
                specgram[E_pos] = abstol

            # Integrate around the west neighbor
            if max_pos[1] > 0 and specgram[W_pos] > abstol:
                phase[W_pos] = phase[max_pos] - (gradfreq[max_pos] +
                                                 gradfreq[W_pos]) / 2
                heapq.heappush(magnitude_heap, (-specgram[W_pos], W_pos))
                specgram[W_pos] = abstol

        # Find the new maximum value to start the integration
        max_val = torch.max(specgram)
        max_pos = torch.where(specgram == max_val)
        heapq.heappush(magnitude_heap, (-max_val, max_pos))
        specgram[max_pos] = abstol

    return phase

Check for a small matrix

x = torch.randn((3, 6)).abs()
print("Input:\n", x)
phase = pghi(x, win_length=4, hop_length=1)
print("PyTorch PGHI:\n", phase)
pphase = ppghi(x.T.numpy(), win_length=4, hop_length=1)
pphase = torch.from_numpy(pphase).T
print("PyPGHI:\n", pphase)
print("Exact solution:", torch.allclose(phase, pphase)) # True