How to Compute a Recurrence Relation in PyTorch?

Hi all,

In my current project, I need to compute the following recurrence relation.

Given a sequence {b} = [b_1, b_2, …, b_n], I would like to compute the sequence {a} as follows:

a_1 = 1
a_2 = a_1 x b_1 + 1
…
a_n = a_{n-1} x b_{n-1} + 1

In my problem, b_i, i=1, 2, …, n, are 2-D matrices.

I can use a for loop to compute the sequence {a}, but it is slow for very long sequences. Is there an efficient way to compute the sequence {a} in PyTorch?
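
For reference, here is a minimal sketch of such a loop. It assumes that "x" means the elementwise product and that the b_i are stacked into one tensor of shape (n, H, W); the function name `recurrence_loop` is only illustrative.

```python
import torch

def recurrence_loop(b: torch.Tensor) -> torch.Tensor:
    """Sequential baseline: a_1 = 1, a_k = a_{k-1} * b_{k-1} + 1.

    Assumptions (not stated in the post): b has shape (n, H, W) with
    b_1..b_n stacked along dim 0, and "x" is the elementwise product.
    """
    a = [torch.ones_like(b[0])]          # a_1 = 1
    for k in range(1, b.shape[0]):
        a.append(a[-1] * b[k - 1] + 1)   # a_{k+1} = a_k * b_k + 1
    return torch.stack(a)                # shape (n, H, W)
```

Each step depends on the previous one, so the loop runs n - 1 small operations in sequence, which is where the slowdown on long sequences comes from.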

I note that without the “+1”, the sequence {a} is just the cumulative product of the sequence {b}, which can be computed quickly with torch.cumprod.
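
Building on the torch.cumprod observation, one possible vectorized form comes from dividing the recurrence by the running product P_k = b_1 x … x b_k (with P_0 = 1), which gives a_n = P_{n-1} x (1/P_0 + 1/P_1 + … + 1/P_{n-1}). The sketch below is only valid under some assumptions that the post does not confirm: "x" is the elementwise product, the b_i are stacked into a tensor of shape (n, H, W), and no entry of any b_i is zero. The name `recurrence_cumprod` is illustrative.

```python
import torch

def recurrence_cumprod(b: torch.Tensor) -> torch.Tensor:
    """Vectorized sketch of a_1 = 1, a_k = a_{k-1} * b_{k-1} + 1.

    Assumptions: b has shape (n, H, W), "x" is elementwise, and b has
    no zero entries (the formula divides by running products of b).
    """
    ones = torch.ones_like(b[:1])
    # P[k] = b_1 * ... * b_k along dim 0, with P[0] = 1; shape (n, H, W)
    P = torch.cat([ones, torch.cumprod(b[:-1], dim=0)], dim=0)
    # a_{k+1} = P_k * sum_{j=0..k} 1 / P_j
    return P * torch.cumsum(1.0 / P, dim=0)
```

This replaces the sequential loop with one cumprod and one cumsum, but the running products can overflow or underflow on long sequences, so it may be worth using float64 and checking the result against a plain loop on a small input. It does not carry over directly if "x" is a matrix product.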

Thank you very much in advance!

Tan Nguyen