Is there an equivalent of `jax.lax.scan` in PyTorch (e.g., in `torch.func`)?

I would like to translate the following JAX code, which implements a Kalman filter, to PyTorch.

def kf(params, emissions, return_covs=False):
    """Run a Kalman filter over ``emissions`` using ``lax.scan``.

    Each scan step scores the observation under the current prediction,
    conditions the prediction on that observation, and then propagates the
    filtered estimate through the dynamics to form the next prediction.

    Args:
        params: dict with keys 'F', 'Q', 'R', 'Ht', 'mu0', 'Sigma0'.
            ``params['Ht'][t]`` is the emission matrix used at step ``t``.
        emissions: observations indexable by time step.
        return_covs: if True, also collect the filtered covariances.

    Returns:
        ``(ll, filtered_means, filtered_covs)``. ``ll`` is the accumulated
        log-probability of the emissions. ``filtered_covs`` is ``None`` when
        ``return_covs`` is False.
    """
    dynamics, process_noise, obs_noise = params['F'], params['Q'], params['R']

    def filter_step(state, idx):
        loglik, mean_pred, cov_pred = state
        H_t = params['Ht'][idx]
        obs = emissions[idx]
        # Marginal predictive distribution of the observation at this step.
        predictive = MVN(H_t @ mean_pred, H_t @ cov_pred @ H_t.T + obs_noise)
        loglik = loglik + predictive.log_prob(obs)
        mean_filt, cov_filt = condition_on(mean_pred, cov_pred, H_t, obs_noise, obs)
        mean_next, cov_next = predict(mean_filt, cov_filt, dynamics, process_noise)
        # Emitting None for the covariances keeps scan from stacking them.
        per_step = (mean_filt, cov_filt if return_covs else None)
        return (loglik, mean_next, cov_next), per_step

    init_state = (0.0, params['mu0'], params['Sigma0'])
    step_indices = jnp.arange(len(emissions))
    final_state, stacked = lax.scan(filter_step, init_state, step_indices)
    filtered_means, filtered_covs = stacked
    return final_state[0], filtered_means, filtered_covs

The PyTorch code below works, but it is about 5x slower (on CPU), presumably because of the Python `for` loop.
Is there an equivalent of `scan`?

def kf_pt(params, emissions, return_covs=False, compile=False):
    """Run a Kalman filter over ``emissions`` with an explicit Python loop.

    This is a PyTorch port of the ``lax.scan``-based ``kf``. At each step it
    scores the observation under the current prediction, conditions the
    prediction on the observation, and then propagates the filtered estimate
    through the dynamics.

    Args:
        params: dict with keys 'F', 'Q', 'R', 'Ht', 'mu0', 'Sigma0'.
            ``params['Ht'][t]`` is the emission matrix used at step ``t``.
        emissions: observations indexable by time step.
        return_covs: if True, also return the filtered covariances.
        compile: if True, wrap the per-step function in ``torch.compile``.

    Returns:
        ``(ll, filtered_means, filtered_covs)``. ``ll`` is the summed log
        marginal likelihood of the emissions (``0`` when there are none).
        ``filtered_means`` has shape ``(T, D)``. ``filtered_covs`` has shape
        ``(T, D, D)``, or is ``None`` when ``return_covs`` is False.
    """
    from torch.distributions import MultivariateNormal

    F, Q, R = params['F'], params['Q'], params['R']

    def step(carry, t):
        ll, pred_mean, pred_cov = carry
        H = params['Ht'][t]
        y = emissions[t]
        # Predictive covariance of y_t. Symmetrize it because round-off in
        # H P H^T can make it fail MultivariateNormal's symmetry check.
        S = H @ pred_cov @ H.T + R
        S = 0.5 * (S + S.T)
        ll = ll + MultivariateNormal(H @ pred_mean, S).log_prob(y)
        filtered_mean, filtered_cov = condition_on_pt(pred_mean, pred_cov, H, R, y)
        pred_mean, pred_cov = predict_pt(filtered_mean, filtered_cov, F, Q)
        carry = (ll, pred_mean, pred_cov)
        if return_covs:
            return carry, (filtered_mean, filtered_cov)
        return carry, filtered_mean

    if compile:
        step = torch.compile(step)
    num_timesteps = len(emissions)
    D = len(params['mu0'])
    # Allocate the output buffers with the initial mean's dtype/device.
    # Otherwise float64 or GPU results are silently cast into float32 CPU
    # tensors when written below.
    mu0_ref = torch.as_tensor(params['mu0'])
    buf_kwargs = dict(dtype=mu0_ref.dtype, device=mu0_ref.device)
    filtered_means = torch.zeros((num_timesteps, D), **buf_kwargs)
    if return_covs:
        filtered_covs = torch.zeros((num_timesteps, D, D), **buf_kwargs)
    else:
        filtered_covs = None
    carry = (0, params['mu0'], params['Sigma0'])
    for t in range(num_timesteps):
        if return_covs:
            carry, (filtered_means[t], filtered_covs[t]) = step(carry, t)
        else:
            carry, filtered_means[t] = step(carry, t)
    # The accumulated log-likelihood lives in the carry, not in a local.
    ll = carry[0]
    return ll, filtered_means, filtered_covs
1 Like

It doesn't exist quite yet, but feel free to add your support in the GitHub issue where it's being discussed: pytorch/pytorch#50688, "[feature request] torch.scan (also port lax.fori_loop / lax.while_loop, lax.cond / lax.switch?)".

1 Like