Hi,
I have been working on reconstructing a continuous signal for a vision project. The reconstruction is perfect when I do not use any batch information, i.e. with a signal of length 3000 and shape [3000],
but when I wrap the signal into a single batch with shape [1, 3000],
my reconstruction error increases and the values look as if they were scaled by a factor of 0.5. Obviously this should not happen. I have spent the past few days examining this in more detail, but was unable to resolve the issue, so I am turning to the community.
I investigated several possible causes: 1) whether batching is accounted for when computing the reconstruction, 2) where the scaling could occur, and 3) whether we are indexing into the correct values. All of them led to a dead end. My code is included below for reference. Please let me know if you spot something that could cause or resolve my problem.
Libraries
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import math
from scipy import linalg as la
from scipy import signal
from scipy import special as ss
# Device for the HiPPO model and the input tensor inside reconstruct().
device = torch.device("cpu")
Model
def transition(measure, N, **measure_args):
    """Return the continuous-time HiPPO state matrices ``(A, B)`` for ``measure``.

    measure: one of 'lagt', 'legt', 'legs', 'fourier'.
    N: state size (order of the projection).
    measure_args: extra options; only 'beta' (used by 'lagt', default 1.0) is read.

    Returns float64 numpy arrays A of shape (N, N) and B of shape (N, 1).

    Raises ValueError for an unknown measure. Previously this surfaced as an
    UnboundLocalError at the return statement.
    """
    # Laguerre (translated)
    if measure == 'lagt':
        b = measure_args.get('beta', 1.0)
        A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
        B = b * np.ones((N, 1))
    # Legendre (translated)
    elif measure == 'legt':
        Q = np.arange(N, dtype=np.float64)
        R = (2 * Q + 1) ** .5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.) ** (i - j), 1) * R[None, :]
        B = R[:, None]
        A = -A
    # Legendre (scaled). M is lower triangular and T is diagonal, so A is lower triangular.
    elif measure == 'legs':
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        # np.diag returns a read-only view; copy it, otherwise torch.as_tensor(B)
        # warns "UserWarning: given NumPY array is not writeable...".
        B = B.copy()
    # Fourier
    elif measure == 'fourier':
        # NOTE(review): assumes N is even. For odd N, A below is (N-1, N-1) while
        # B has N entries, so the subtraction fails to broadcast.
        freqs = np.arange(N // 2)
        d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]
        A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2
        B[0] = 2 ** .5
        A = A - B[:, None] * B[None, :]
        B *= 2 ** .5
        B = B[:, None]
    else:
        raise ValueError(
            f"unknown measure {measure!r}; expected one of 'lagt', 'legt', 'legs', 'fourier'"
        )
    return A, B
class HiPPOScale(nn.Module):
    """Vanilla HiPPO-LegS model (scale invariant instead of time invariant).

    The per-step matrices for t = 1 .. max_length are discretized once in
    ``__init__``. They are stored in the ``A_stacked`` (max_length, N, N) and
    ``B_stacked`` (max_length, N) buffers, and ``forward`` applies them step by step.
    """

    def __init__(self, N, method='legs', max_length=1024, discretization='bilinear'):
        """
        N: order of the HiPPO projection (state size).
        method: measure name passed to ``transition``.
        max_length: maximum sequence length.
        discretization: 'forward', 'backward' or 'bilinear'; any other value
            selects the ZOH branch.
        """
        super().__init__()
        self.N = N
        A, B = transition(method, N)
        B = B.squeeze(-1)
        eye = np.eye(N)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        # Scale-invariant system x' = (1/t)(A x + B u): step t uses A/t and B/t.
        # NOTE(review): the triangular solves pass lower=True, which is only valid
        # when A is lower triangular (true for 'legs'). Confirm before using other
        # measures with these discretizations.
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == 'forward':
                A_stacked[t - 1] = eye + At
                B_stacked[t - 1] = Bt
            elif discretization == 'backward':
                A_stacked[t - 1] = la.solve_triangular(eye - At, eye, lower=True)
                B_stacked[t - 1] = la.solve_triangular(eye - At, Bt, lower=True)
            elif discretization == 'bilinear':
                A_stacked[t - 1] = la.solve_triangular(eye - At / 2, eye + At / 2, lower=True)
                B_stacked[t - 1] = la.solve_triangular(eye - At / 2, Bt, lower=True)
            else:  # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True)
        self.register_buffer('A_stacked', torch.Tensor(A_stacked))  # (max_length, N, N)
        self.register_buffer('B_stacked', torch.Tensor(B_stacked))  # (max_length, N)

        # Reconstruction basis: entry [m, n] is B[n] * P_n(2 x_m - 1), where x_m runs
        # over max_length evenly spaced points in [0, 1]. This is a plain attribute,
        # not a buffer, so Module.to() does not move it; reconstruct() moves it onto c.
        grid = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.Tensor(
            (B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * grid - 1)).T
        )  # (max_length, N)

    def forward(self, inputs, fast=False):
        """Run the HiPPO recurrence along the LEADING axis of ``inputs``.

        inputs : (length, ...). Time must be axis 0; any trailing axes (e.g. a
            batch axis) are processed independently. A (batch, length) tensor must
            be transposed first, otherwise it is read as ``batch`` time steps of
            ``length`` separate sequences.
        output : (length, ..., N) where N is the order of the HiPPO projection

        Raises ValueError if ``length`` exceeds the ``max_length`` used at
        construction, and NotImplementedError for ``fast=True`` when no
        module-level ``unroll`` is available.
        """
        L = inputs.shape[0]
        max_length = self.A_stacked.shape[0]
        if L > max_length:
            raise ValueError(
                f"sequence length {L} exceeds max_length={max_length} of this HiPPOScale"
            )
        inputs = inputs.unsqueeze(-1)
        # Scale every step by its own B: u[t] = inputs[t, ..., None] * B_stacked[t].
        u = torch.transpose(inputs, 0, -2)
        u = u * self.B_stacked[:L]
        u = torch.transpose(u, 0, -2)  # (length, ..., N)
        if fast:
            # `unroll` is not defined in this file; only use it if it was imported
            # at module level.
            unroll_module = globals().get('unroll')
            if unroll_module is None:
                raise NotImplementedError(
                    "fast=True requires a module-level `unroll` providing variable_unroll_matrix"
                )
            return unroll_module.variable_unroll_matrix(self.A_stacked[:L], u)
        if L == 0:
            # torch.stack([]) would fail; return an empty (0, ..., N) result.
            return u.new_zeros(u.shape)
        c = torch.zeros(u.shape[1:]).to(inputs)
        cs = []
        for t in range(L):
            # F.linear(c, A) == c @ A.T, i.e. A_t applied to each state vector.
            c = F.linear(c, self.A_stacked[t]) + u[t]
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        """Evaluate the expansion with coefficients ``c`` on the eval_matrix grid.

        c : (..., N)  ->  returns (..., max_length, 1)
        """
        return self.eval_matrix.to(c) @ c.unsqueeze(-1)
Synthetic data generation
def whitesignal(period, dt, freq, rms=0.5, batch_shape=(), rng=None):
    """Generate band-limited white noise that starts at 0.

    period: signal duration; dt: sample spacing.
    freq: band limit. Fourier components above ``freq`` are zeroed; ``None``
        disables band limiting (equivalent to the Nyquist frequency).
    rms: intended amplitude scale of the random Fourier coefficients.
    batch_shape: leading batch dimensions (any sequence of ints).
    rng: optional random source exposing ``normal(loc, scale, size)``, e.g.
        ``np.random.default_rng(seed)``. Defaults to the global ``np.random`` state.

    Returns an array of shape (*batch_shape, 2 * ceil(period / dt / 2)). Time is
    on the LAST axis, and every signal is shifted so its first sample is 0.

    Raises ValueError if ``freq < 1 / period`` (the signal would be all zeros)
    or ``freq`` exceeds the Nyquist frequency ``0.5 / dt``.
    """
    if rng is None:
        rng = np.random
    batch_shape = tuple(batch_shape)
    nyquist_cutoff = 0.5 / dt
    if freq is None:
        freq = nyquist_cutoff
    if freq < 1. / period:
        raise ValueError(f"Make ``{freq=} >= 1. / {period=}`` to produce a non-zero signal")
    if freq > nyquist_cutoff:
        raise ValueError(f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})")

    n_coefficients = int(np.ceil(period / dt / 2.))
    shape = batch_shape + (n_coefficients + 1,)
    sigma = rms * np.sqrt(0.5)
    # Complex coefficients: the Nyquist bin must be real, and the DC bin is dropped.
    coefficients = 1j * rng.normal(0., sigma, size=shape)
    coefficients[..., -1] = 0.
    coefficients += rng.normal(0., sigma, size=shape)
    coefficients[..., 0] = 0.

    # Band-limit, then rescale so the retained bins carry the full power.
    set_to_zero = np.fft.rfftfreq(2 * n_coefficients, d=dt) > freq
    coefficients *= (1 - set_to_zero)
    power_correction = np.sqrt(1. - np.sum(set_to_zero, dtype=float) / n_coefficients)
    if power_correction > 0.:
        coefficients /= power_correction
    coefficients *= np.sqrt(2 * n_coefficients)

    # Named `samples` (not `signal`) to avoid shadowing the scipy.signal import.
    samples = np.fft.irfft(coefficients, axis=-1)
    return samples - samples[..., :1]  # start from 0
Reconstruction
def reconstruct(T, dt, N, freq, vals, u):
    """Encode ``u`` with HiPPO-LegS and reconstruct it from the coefficients.

    T, dt: signal duration and sample spacing; the model is built for
        ``round(T / dt)`` steps.
    N: order of the HiPPO projection.
    freq: not used here; kept so existing call sites keep working.
    vals: returned unchanged.
    u: signal with time on the LAST axis, shape (length,) or
        (*batch_shape, length), the layout whitesignal() produces.

    Returns ``(u, u_hippo, u_hippo_all, vals)``:
        u: the input as a float tensor on ``device``, in its original layout.
        u_hippo: reconstruction from the final coefficients,
            shape (*batch_shape, max_length, 1).
        u_hippo_all: reconstruction after every step,
            shape (length, *batch_shape, max_length, 1).
    """
    u = torch.as_tensor(u, dtype=torch.float).to(device)

    # HiPPOScale.forward treats axis 0 as time ((length, ...)), but the signal
    # arrives as (*batch_shape, length). Passed through as-is, a (1, 3000) batch
    # is read as ONE time step of 3000 independent sequences, which broke the
    # batched reconstruction. Move time to the front (a no-op for a 1-D signal).
    u_time_first = torch.movedim(u, -1, 0)

    # Original HiPPO-LegS, which uses the time-varying SSM x' = 1/t [Ax + Bu]
    # ("linear scale invariant"). round() guards against T / dt evaluating to
    # slightly below an integer and being truncated by int().
    hippo = HiPPOScale(N=N, method='legs', max_length=int(round(T / dt))).to(device)
    coeffs = hippo(u_time_first)  # (length, *batch_shape, N); run the model once
    u_hippo_all = hippo.reconstruct(coeffs).cpu()
    u_hippo = u_hippo_all[-1]
    return u, u_hippo, u_hippo_all, vals
Visualise signal with no batch information
T, dt = 3, 1e-3  # signal duration and sample spacing
# `vals` (the plotting x-axis) is used below but was never defined in this script.
vals = np.linspace(0.0, T, int(round(T / dt)), endpoint=False)
u = whitesignal(T, dt, 3.0, batch_shape=())
sig, recon, recon_all_t, values = reconstruct(T=T, dt=dt, N=64, freq=3.0, vals=vals, u=u)
plt.plot(values[-len(sig):], sig, label='u=target', linewidth=2, dashes=(5, 10))
plt.plot(values[-len(recon):], recon, label='recon', dashes=(5, 1))
plt.legend()
plt.show()