How to accommodate batches when computing/visualizing a reconstruction?

Hi,

I have been working on reconstructing a continuous signal for a vision project. The reconstruction is perfect when I do not use any batch information, i.e. with a signal of length 3000 and shape [3000], but when I wrap the signal into a single batch with shape [1, 3000], the reconstruction error increases and the output looks as if the values were scaled by a factor of 0.5. Obviously this should not occur. I have been examining this in detail for the past few days, but was unable to resolve the issue and am turning to the community.

I investigated several possible sources of error: 1) whether batching is accounted for when computing the reconstruction, 2) where the scaling could occur, and 3) whether I am indexing into the correct values, but all of these led to a dead end. Please find my code below for reference, and let me know if you spot anything that could cause or resolve the problem.

Libraries

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import math
from scipy import linalg as la
from scipy import signal
from scipy import special as ss
device = torch.device("cpu")

Model

def transition(measure, N, **measure_args):
    # Laguerre (translated)
    if measure == 'lagt':
        b = measure_args.get('beta', 1.0)
        A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
        B = b * np.ones((N, 1))
    # Legendre (translated)
    elif measure == 'legt':
        Q = np.arange(N, dtype=np.float64)
        R = (2*Q + 1) ** .5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :]
        B = R[:, None]
        A = -A
    # Legendre (scaled)
    elif measure == 'legs':
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = B.copy() # otherwise "UserWarning: given NumPy array is not writeable..." after torch.as_tensor(B)
    elif measure == 'fourier':
        freqs = np.arange(N//2)
        d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = 2*np.pi*(-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2
        B[0] = 2**.5
        A = A - B[:, None] * B[None, :]
        # A = A - np.eye(N)
        B *= 2**.5
        B = B[:, None]

    return A, B
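
For reference, a quick sanity check of the shapes returned by transition (a small sketch, not part of my actual pipeline):

A, B = transition('legs', 4)
print(A.shape, B.shape)  # (4, 4) (4, 1) -> state matrix A and input vector B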

class HiPPOScale(nn.Module):
    """ Vanilla HiPPO-LegS model (scale invariant instead of time invariant) """
    def __init__(self, N, method='legs', max_length=1024, discretization='bilinear'):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        A, B = transition(method, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == 'forward':
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == 'backward':
                A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True)
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == 'bilinear':
                A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True)
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, Bt, lower=True)
            else: # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True)
        self.register_buffer('A_stacked', torch.Tensor(A_stacked)) # (max_length, N, N)
        self.register_buffer('B_stacked', torch.Tensor(B_stacked)) # (max_length, N)

        vals = np.linspace(0.0, 1.0, max_length)
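        # Legendre polynomials evaluated on a grid of max_length points in [0, 1],
        # scaled by B; reconstruct() maps coefficients back through this matrix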
        self.eval_matrix = torch.Tensor((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)

    def forward(self, inputs, fast=False):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        L = inputs.shape[0]

        inputs = inputs.unsqueeze(-1)
        u = torch.transpose(inputs, 0, -2)
        u = u * self.B_stacked[:L]
        u = torch.transpose(u, 0, -2) # (length, ..., N)

        if fast:
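            # requires the `unroll` module from the original HiPPO codebase,
            # which is not imported above; I only use the fast=False path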
            result = unroll.variable_unroll_matrix(self.A_stacked[:L], u)
            return result

        c = torch.zeros(u.shape[1:]).to(inputs)
        cs = []
        for t, f in enumerate(inputs):
            c = F.linear(c, self.A_stacked[t]) + self.B_stacked[t] * f
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        a = self.eval_matrix.to(c) @ c.unsqueeze(-1)
        return a
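
Per the docstring, the length dimension comes first; a minimal shape check of the model as I understand it (sketch only):

hippo_test = HiPPOScale(N=8, max_length=100)
x = torch.randn(100)                    # (length,)
c = hippo_test(x)
print(c.shape)                          # torch.Size([100, 8]) -> (length, N)
print(hippo_test.reconstruct(c).shape)  # torch.Size([100, 100, 1])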

Synthetic data generation

def whitesignal(period, dt, freq, rms=0.5, batch_shape=()):
    # Produces an output signal of length period / dt, band-limited to frequency freq
    # Output shape (*batch_shape, period/dt)
    if freq is not None and freq < 1. / period:
        raise ValueError(f"Make ``{freq=} >= 1. / {period=}`` to produce a non-zero signal",)

    nyquist_cutoff = 0.5 / dt
    if freq > nyquist_cutoff:
        raise ValueError(f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})")

    n_coefficients = int(np.ceil(period / dt / 2.))
    shape = batch_shape + (n_coefficients + 1,)
    sigma = rms * np.sqrt(0.5)
    coefficients = 1j * np.random.normal(0., sigma, size=shape)
    coefficients[..., -1] = 0.
    coefficients += np.random.normal(0., sigma, size=shape)
    coefficients[..., 0] = 0.

    set_to_zero = np.fft.rfftfreq(2 * n_coefficients, d=dt) > freq
    coefficients *= (1-set_to_zero)
    power_correction = np.sqrt(1. - np.sum(set_to_zero, dtype=float) / n_coefficients)
    if power_correction > 0.: coefficients /= power_correction
    coefficients *= np.sqrt(2 * n_coefficients)
    signal = np.fft.irfft(coefficients, axis=-1)
    signal = signal - signal[..., :1]  # Start from 0
    return signal
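
For reference, this is how I generate the signal with and without a batch dimension (small sketch):

u_flat = whitesignal(3, 1e-3, 3.0)                       # shape (3000,)
u_batched = whitesignal(3, 1e-3, 3.0, batch_shape=(1,))  # shape (1, 3000)
print(u_flat.shape, u_batched.shape)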

Reconstruction

def reconstruct(T, dt, N, freq, vals, u):
    u = torch.tensor(u, dtype=torch.float)
    u = u.to(device)

    # Linear Time Invariant (LTI) methods x' = Ax + Bu
    lti_methods = [
        'legt',
    ]

    # Original HiPPO-LegS, which uses time-varying SSM x' = 1/t [ Ax + Bu]
    # we call this "linear scale invariant"
    lsi_methods = ['legs']
    for method in lsi_methods:
        hippo = HiPPOScale(N=N, method=method, max_length=int(T/dt)).to(device)
        u_hippo = hippo.reconstruct(hippo(u))[-1].cpu()
        u_hippo_all = hippo.reconstruct(hippo(u)).cpu()
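        # note: hippo(u) runs twice here, which is why the shape printouts
        # further below appear twice per reconstruction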
    
    return u, u_hippo, u_hippo_all, vals

Visualise signal with no batch information

u = whitesignal(3, 1e-3, 3.0, batch_shape=())
vals = np.linspace(0.0, 3.0, int(3 / 1e-3))  # time axis for plotting (assumed definition)
sig, recon, recon_all_t, values = reconstruct(T=3, dt=1e-3, N=64, freq=3.0, vals=vals, u=u)

plt.plot(values[-len(sig):], sig, label='u=target', linewidth=2, dashes=(5, 10))
plt.plot(values[-len(recon):], recon, label='recon', dashes=(5, 1))
plt.legend()

Visualise signal with 1 batch dimension

u = whitesignal(3, 1e-3, 3.0, batch_shape=(1,))  # wrap the signal into a single batch, shape (1, 3000)
sig, recon, recon_all_t, values = reconstruct(T=3, dt=1e-3, N=64, freq=3.0, vals=vals, u=u)

plt.plot(values[-len(sig[0]):], sig[0], label='target', linewidth=3)
plt.plot(values[-len(recon[:, -1, 0]):], recon[:, -1, 0], linestyle='--', label='recon', dashes=(5, 1))
plt.legend()

Based on your description I would guess your code might be accidentally broadcasting some tensors, resulting in a wrong computation. Could you check the shapes of the intermediate tensors and see if that’s the case?

Thank you for your answer, ptrblck. I have looked into the shapes of the intermediate variables and added them for reference below. What should I be looking for? One thought I have is that the coefficient matrix may not be correctly initialised: shouldn’t it be [batch_dim, N], i.e. [1, 64], instead of [3000, 64]? This then propagates through to when the coefficients c are computed, which similarly should be [1, 64] instead of [3000, 64]. But I am not entirely sure, as the other matrices look fine: the stacked coefficients shape looks correct, the eval matrix also looks correct, and the recon shape as well. What do you think, Patrick?

Shapes intermediate variables no batch information

-------------Init------------ torch.Size([64])
Coef c shape: torch.Size([64])
-------------Init------------ torch.Size([64])
Coef  c shape: torch.Size([64])
----------------------------------------
----------Reconstruction part:----------
Eval mat shape: torch.Size([3000, 64])
Coefs stacked shape: torch.Size([3000, 64])
Recon shape: torch.Size([3000, 3000, 1])
----------------------------------------
-------------Init------------ torch.Size([64])
Coef c shape: torch.Size([64])
-------------Init------------ torch.Size([64])
Coef  c shape: torch.Size([64])
----------------------------------------
----------Reconstruction part:----------
Eval mat shape: torch.Size([3000, 64])
Coefs stacked shape: torch.Size([3000, 64])
Recon shape: torch.Size([3000, 3000, 1])
----------------------------------------

Shapes intermediate variables 1 batch dimension

-------------Init------------ torch.Size([3000, 64])
Coef c shape: torch.Size([3000, 64])
-------------Init------------ torch.Size([3000, 64])
Coef  c shape: torch.Size([3000, 64])
----------------------------------------
----------Reconstruction part:----------
Eval mat shape: torch.Size([3000, 64])
Coefs stacked shape: torch.Size([1, 3000, 64])
Recon shape: torch.Size([1, 3000, 3000, 1])
----------------------------------------
-------------Init------------ torch.Size([3000, 64])
Coef c shape: torch.Size([3000, 64])
-------------Init------------ torch.Size([3000, 64])
Coef  c shape: torch.Size([3000, 64])
----------------------------------------
----------Reconstruction part:----------
Eval mat shape: torch.Size([3000, 64])
Coefs stacked shape: torch.Size([1, 3000, 64])
Recon shape: torch.Size([1, 3000, 3000, 1])
----------------------------------------

Thanks for the update!
Some shapes do indeed look incorrect and you might need to check these, but let me give you a quick overview first.
In PyTorch most layers expect an input tensor in the shape [batch_size, *], where * denotes additional dimensions. There are a few exceptions, such as RNNs, which expect the batch dimension in dim1 by default (you can change this via batch_first=True), and the latest PyTorch release also accepts inputs without the batch dimension (it will unsqueeze the missing batch dimension for you), but let’s stick to the common approach.
All intermediate activations will thus also have a shape of [batch_size, *], but note that the trainable parameters and buffers do not depend on the batch size and thus do not have a batch dimension.
Take a look at this simple example:

lin = nn.Linear(8, 16)
print(lin.weight.shape)
# torch.Size([16, 8])
print(lin.bias.shape)
# torch.Size([16])

x = torch.randn(1, 8)
out = lin(x)
print(out.shape)
# torch.Size([1, 16])

x = torch.randn(20, 8)
out = lin(x)
print(out.shape)
# torch.Size([20, 16])

You can see that the weight and bias parameters do not care about the actual batch size and work with different inputs.

In your posted output a few tensors seem to change their shape if a batch size of 1 was used for the input:

# from
Coef c shape: torch.Size([64])
# to
Coef c shape: torch.Size([3000, 64])

which seems unexpected.
I don’t know which exact parameter this is, but even if it were a forward activation I would expect to see a new dimension with a size of 1 (the batch_size) instead of 3000.
Could you check which parameter this is and how its shape depends on the batch size?
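
Something like this might help to narrow it down (just a sketch, adapt the names to your code):

model = HiPPOScale(N=64, max_length=3000)

# buffers (and parameters) should not depend on the batch size
for name, buf in model.named_buffers():
    print(name, buf.shape)  # A_stacked: [3000, 64, 64], B_stacked: [3000, 64]

# compare the coefficient shapes for both input layouts
c = model(torch.randn(3000))      # no batch dimension
cb = model(torch.randn(1, 3000))  # batch dimension of 1
print(c.shape, cb.shape)          # where does the 3000 in dim1 come from?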