Equivalent of TensorFlow's tf.scan or Theano's theano.scan in PyTorch (2020)

This question was asked in 2017, but I'm wondering whether, as of 2020, there is any way to do tf.scan or theano.scan in PyTorch.

In PyTorch, there is torch.cumsum, which can be thought of as a special case of scan. Specifically, cumsum is tied to the addition operator, whereas in TensorFlow or Theano, scan can be used with any binary operator (passed in as a function), not just addition.
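To make the relationship concrete, here is a minimal sketch of a generic inclusive scan written as a plain Python loop over the first dimension. The `scan` helper is hypothetical (it is not part of PyTorch). Passing `torch.add` reproduces `torch.cumsum`, and passing `torch.mul` reproduces `torch.cumprod`. Any other binary operator works as well, but every step runs sequentially.

```python
import torch
from typing import Callable

def scan(op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         xs: torch.Tensor) -> torch.Tensor:
    """Inclusive scan of xs along dim 0 with an arbitrary binary operator.

    Hypothetical helper: performs one operator call per element, so it is sequential.
    """
    outputs = [xs[0]]
    for x in xs[1:]:
        # Combine the running result with the next element.
        outputs.append(op(outputs[-1], x))
    return torch.stack(outputs)

xs = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(scan(torch.add, xs))  # same values as torch.cumsum(xs, dim=0)
print(scan(torch.mul, xs))  # same values as torch.cumprod(xs, dim=0)
```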

From the 2017 thread (linked above), @jekbradbury said:

For equivalents of theano.scan, use Python for and while loops.

However, I would prefer not to use for and while loops, partly because I want the computation to run in parallel on the GPU rather than one step at a time. Scans (a.k.a. prefix sums) are a common GPGPU operation, which is why CUDA libraries like Thrust provide a variety of prefix-sum functions.
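One way to get parallelism without a dedicated primitive is a Hillis-Steele style scan built from ordinary batched tensor ops. The sketch below is my own illustration, not a PyTorch API (`parallel_scan` is a hypothetical name). It assumes `op` is associative and applies elementwise across the leading dimension. It needs about log2(n) vectorized steps instead of n sequential ones, at the cost of O(n log n) total work.

```python
import torch
from typing import Callable

def parallel_scan(op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                  xs: torch.Tensor) -> torch.Tensor:
    """Inclusive Hillis-Steele scan along dim 0 (hypothetical helper).

    Assumes `op` is associative and works elementwise on batched tensors.
    Argument order is preserved, so non-commutative ops (e.g. matmul) are fine.
    """
    out = xs.clone()
    n = out.size(0)
    offset = 1
    while offset < n:
        # Position i >= offset combines the partial result at i - offset (left)
        # with its own partial result (right), all positions at once.
        combined = op(out[:-offset], out[offset:])
        out = torch.cat([out[:offset], combined], dim=0)
        offset *= 2
    return out

xs = torch.arange(1.0, 9.0)
print(parallel_scan(torch.add, xs))  # same values as torch.cumsum(xs, dim=0)
```

Because each step is one batched call to `op`, the same code runs on the GPU if `xs` lives there. For example, `parallel_scan(torch.matmul, ms)` on a stack of matrices with shape `(n, k, k)` computes the running matrix products.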

Has anyone (officially or unofficially) implemented a scan in PyTorch, or am I stuck with Python for and while loops?


Did you try writing the loop in TorchScript?
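A minimal sketch of what that could look like, using a first-order linear recurrence as a hypothetical example operation. TorchScript compiles the loop and removes Python interpreter overhead, but the time steps still execute one after another. The operator is hard-coded here because a scripted function does not easily take an arbitrary Python function as an argument.

```python
import torch

@torch.jit.script
def linear_recurrence(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical example op: h[t] = a[t] * h[t - 1] + b[t], with h[-1] = 0.
    # Compiled by TorchScript, but still one step per loop iteration.
    out = torch.empty_like(b)
    h = torch.zeros_like(b[0])
    for t in range(b.size(0)):
        h = a[t] * h + b[t]
        out[t] = h
    return out

a = torch.tensor([0.5, 2.0, 1.0, 0.25])
b = torch.tensor([1.0, 1.0, 2.0, 4.0])
print(linear_recurrence(a, b))
```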

Did you find an answer to your question?

Nope, I never found an answer. I ended up using a combination of torch.cumsum and torch.cumprod to implement the operation that I needed, but it would still be convenient for PyTorch to have a scan primitive that accepts an arbitrary binary operator, like in TensorFlow/Theano or even Thrust.
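The post doesn't say which operation was needed. As a hypothetical illustration of how torch.cumsum and torch.cumprod can be combined, take the linear recurrence h[t] = a[t] * h[t - 1] + b[t] with h[-1] = 0. Unrolling it gives h[t] = P[t] * sum over s <= t of b[s] / P[s], where P = cumprod(a). This sketch assumes every a[t] is nonzero. The division can also lose precision when P[t] becomes very small or very large, which is one reason a general scan primitive would be preferable.

```python
import torch

def linear_recurrence_cumsum(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical example: h[t] = a[t] * h[t - 1] + b[t], h[-1] = 0, along dim 0.
    # Closed form: h[t] = P[t] * cumsum(b / P)[t], with P = cumprod(a).
    # Assumes all a[t] != 0; may be numerically unstable for long sequences.
    p = torch.cumprod(a, dim=0)
    return p * torch.cumsum(b / p, dim=0)

a = torch.tensor([0.5, 2.0, 1.0, 0.25])
b = torch.tensor([1.0, 1.0, 2.0, 4.0])
print(linear_recurrence_cumsum(a, b))
```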
