I would like to translate the following JAX code (which implements a Kalman filter) to PyTorch.
def kf(params, emissions, return_covs=False):
    """Run a Kalman filter over ``emissions`` using ``lax.scan``.

    Args:
        params: dict providing 'F', 'Q', 'R', 'Ht', 'mu0' and 'Sigma0'.
            'Ht' is indexed by timestep (``params['Ht'][t]``).
        emissions: observations, indexed by timestep along the leading axis.
        return_covs: when True, the filtered covariances are stacked and
            returned; otherwise the third return value is None.

    Returns:
        Tuple ``(ll, filtered_means, filtered_covs)``: the log-likelihood
        accumulated over all timesteps (starting from 0.0), the stacked
        filtered means, and the stacked filtered covariances (or None).
    """
    dynamics, dynamics_noise, obs_noise = params['F'], params['Q'], params['R']

    def _filter_step(state, t):
        log_lik, mean, cov = state
        emission_mat = params['Ht'][t]
        obs = emissions[t]
        # One-step-ahead predictive distribution of the observation.
        obs_mean = emission_mat @ mean
        obs_cov = emission_mat @ cov @ emission_mat.T + obs_noise
        log_lik += MVN(obs_mean, obs_cov).log_prob(obs)
        post_mean, post_cov = condition_on(mean, cov, emission_mat, obs_noise, obs)
        next_mean, next_cov = predict(post_mean, post_cov, dynamics, dynamics_noise)
        # None is an empty pytree, so lax.scan stacks nothing for it.
        per_step = (post_mean, post_cov if return_covs else None)
        return (log_lik, next_mean, next_cov), per_step

    init_state = (0.0, params['mu0'], params['Sigma0'])
    steps = jnp.arange(len(emissions))
    (log_lik, _, _), (filtered_means, filtered_covs) = lax.scan(
        _filter_step, init_state, steps
    )
    return log_lik, filtered_means, filtered_covs
The PyTorch code below works, but it is about 5x slower than the JAX version (on CPU), presumably because of the Python `for` loop.
Is there a PyTorch equivalent of `jax.lax.scan`?
def kf_pt(params, emissions, return_covs=False, compile=False):
    """Run a Kalman filter over ``emissions`` with PyTorch (port of the JAX ``kf``).

    Args:
        params: dict providing 'F', 'Q', 'R', 'Ht', 'mu0' and 'Sigma0'.
            'Ht' is indexed by timestep (``params['Ht'][t]``).
        emissions: observations, indexed by timestep along the leading axis.
        return_covs: when True, also collect the filtered covariances.
        compile: when True, wrap the per-step update in ``torch.compile``.

    Returns:
        Tuple ``(ll, filtered_means, filtered_covs)``:
        ``ll`` is a 0-d tensor holding the log-likelihood summed over all
        timesteps; ``filtered_means`` has shape ``(T, D)`` with
        ``D = params['mu0'].shape[-1]``; ``filtered_covs`` has shape
        ``(T, D, D)``, or is None when ``return_covs`` is False.
        All outputs use the dtype and device of ``params['mu0']``.
    """
    F, Q, R = params['F'], params['Q'], params['R']
    Ht = params['Ht']

    # The step receives H and y directly instead of the integer t. Indexing
    # therefore stays outside the (optionally) compiled function, and
    # torch.compile is not handed a changing Python int (which can trigger
    # recompilation).
    def step(ll, pred_mean, pred_cov, H, y):
        # validate_args=False skips the per-step constraint checks. A
        # non-positive-definite covariance still fails in the Cholesky
        # factorization done by the constructor.
        pred_y = torch.distributions.MultivariateNormal(
            H @ pred_mean,
            covariance_matrix=H @ pred_cov @ H.T + R,
            validate_args=False,
        )
        ll = ll + pred_y.log_prob(y)
        filtered_mean, filtered_cov = condition_on_pt(pred_mean, pred_cov, H, R, y)
        pred_mean, pred_cov = predict_pt(filtered_mean, filtered_cov, F, Q)
        return ll, pred_mean, pred_cov, filtered_mean, filtered_cov

    if compile:
        step = torch.compile(step)

    mu0 = params['mu0']
    num_timesteps = len(emissions)
    D = mu0.shape[-1]
    # Match mu0's dtype/device rather than torch's default (float32 on CPU),
    # which would silently downcast float64 results or break on GPU inputs.
    filtered_means = torch.zeros((num_timesteps, D), dtype=mu0.dtype, device=mu0.device)
    if return_covs:
        filtered_covs = torch.zeros(
            (num_timesteps, D, D), dtype=mu0.dtype, device=mu0.device
        )
    else:
        filtered_covs = None

    # Use a tensor accumulator (not a Python int) so its type is stable
    # across steps, and return the accumulated value rather than the initial one.
    ll = torch.zeros((), dtype=mu0.dtype, device=mu0.device)
    pred_mean, pred_cov = mu0, params['Sigma0']
    for t in range(num_timesteps):
        ll, pred_mean, pred_cov, filtered_mean, filtered_cov = step(
            ll, pred_mean, pred_cov, Ht[t], emissions[t]
        )
        filtered_means[t] = filtered_mean
        if return_covs:
            filtered_covs[t] = filtered_cov
    return ll, filtered_means, filtered_covs