This question was asked in 2017, but I'm wondering now (2020) whether there is any way to do `tf.scan` or `theano.scan` in PyTorch.
In PyTorch, there is `torch.cumsum`, which can be thought of as a special case of scan. Specifically, `cumsum` is tied to the addition operator, whereas in TensorFlow or Theano, `scan` can be used with any binary operator (passed in as a function), not just addition.
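For concreteness, here is a minimal sketch of the loop-based scan, assuming the usual semantics (carry a running value and apply a binary function at each step). `sequential_scan` is a hypothetical helper, not a PyTorch API. With `torch.add` it gives the same values as `torch.cumsum`, but any binary function can be passed in, including ones `cumsum` cannot express.

```python
import torch

def sequential_scan(fn, xs, init):
    """Inclusive scan along dim 0: carry = fn(carry, x) for each x, starting from init.

    Hypothetical helper (not a PyTorch API); `fn` can be any binary function.
    """
    carry = init
    outputs = []
    for x in xs:  # one Python-level iteration per element, so the steps run serially
        carry = fn(carry, x)
        outputs.append(carry)
    return torch.stack(outputs)


xs = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])

sums = sequential_scan(torch.add, xs, torch.tensor(0.0))                 # same values as torch.cumsum(xs, dim=0)
maxes = sequential_scan(torch.maximum, xs, torch.tensor(float("-inf")))  # running maximum
decayed = sequential_scan(lambda h, x: 0.5 * h + x, xs, torch.tensor(0.0))  # arbitrary recurrence
print(sums, maxes, decayed)
```

The catch is the `for` loop: each step depends on the previous carry and is dispatched from Python one element at a time.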
From the 2017 thread (linked above), @jekbradbury said:

> For equivalents of `theano.scan`, use Python `for` and `while` loops.
However, I would prefer not to use `for` and `while` loops, in part because I want my computation to run in parallel on the GPU. Scans, a.k.a. prefix sums, are a common GPGPU operation, which is why CUDA libraries like Thrust contain a variety of prefix sum functions.
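When the operator is associative, one known way to get most of the parallelism with stock PyTorch ops is the Hillis-Steele inclusive-scan pattern, sketched below. This is not necessarily how Thrust implements its scans. The sketch assumes `op` is an elementwise associative function such as `torch.add`, `torch.mul` or `torch.maximum`, and that `identity` is its identity element (representable in the tensor's dtype). `hillis_steele_scan` is a hypothetical name, not part of PyTorch. There is still a Python `while` loop, but it runs only about log2(n) times, and each iteration is a single tensor operation over all n elements.

```python
import torch

def hillis_steele_scan(xs, op, identity):
    """Inclusive scan of `xs` along dim 0 in ceil(log2(n)) vectorized steps.

    Assumes `op` is associative and elementwise (e.g. torch.add, torch.mul,
    torch.maximum) and that `identity` is its identity element.
    """
    out = xs.clone()
    n = out.shape[0]
    offset = 1
    while offset < n:
        # Shift the partial results right by `offset`, filling the front with
        # the identity so the first `offset` positions are left unchanged.
        pad = torch.full_like(out[:offset], identity)
        shifted = torch.cat([pad, out[:-offset]], dim=0)
        # Position i combines the partial result from i - offset with its own;
        # all n combinations happen in one tensor op, so they can run in parallel.
        out = op(shifted, out)
        offset *= 2
    return out


device = "cuda" if torch.cuda.is_available() else "cpu"
xs = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0], device=device)

sums = hillis_steele_scan(xs, torch.add, 0.0)                  # same values as torch.cumsum(xs, dim=0)
maxes = hillis_steele_scan(xs, torch.maximum, float("-inf"))   # running maximum
print(sums, maxes)
```

The design trade-off: Hillis-Steele does O(n log n) total work instead of the sequential loop's O(n), in exchange for only O(log n) dependent steps. A work-efficient (Blelloch-style) scan reduces the total work at the cost of more code. Neither helps with non-associative recurrences like the `0.5 * h + x` example, which genuinely need sequential steps.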
Has anyone (officially or unofficially) implemented a `scan` in PyTorch, or am I stuck with Python `for` and `while` loops?