I have written the following code to mimic tf.scan:
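import torch

def scan(foo, x):
    # Running accumulation: res[0] = x[0], res[i] = foo(res[i-1], x[i])
    res = [x[0].unsqueeze(0)]
    a_ = x[0].clone()
    for i in range(1, len(x)):
        a_ = foo(a_, x[i])  # call foo once per step and reuse the result
        res.append(a_.unsqueeze(0))
    return torch.cat(res)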
It generates the desired output for a number of examples. My only question is whether the append and torch.cat steps in this code break the backpropagation computation.
No, backpropagation will not be broken by append(). Even though
you are appending to a Python list, you’re still (presumably) appending
valid PyTorch tensors, and cat() is a differentiable PyTorch tensor operation.
The list only holds references to the tensors, so each one keeps its place
in the autograd graph.
(Of course, something in foo() might break backpropagation.)
Here is a simple script that illustrates backpropagating through append() and cat():
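A minimal sketch of such a check, not necessarily the answerer's own script. It redefines scan from the question, uses plain addition as foo (an assumed example), and compares the gradient against the value worked out by hand. The second half illustrates the caveat about foo(): the hypothetical foo_detached calls .detach(), which cuts the graph, so only the first element still receives a gradient.

```python
import torch

def scan(foo, x):
    # Same idea as the scan in the question: collect tensors in a list, then cat.
    res = [x[0].unsqueeze(0)]
    a_ = x[0].clone()
    for i in range(1, len(x)):
        a_ = foo(a_, x[i])
        res.append(a_.unsqueeze(0))
    return torch.cat(res)

# Arbitrary example values; requires_grad=True so autograd tracks x.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# foo = addition, so scan returns the cumulative sum [x0, x0+x1, x0+x1+x2].
out = scan(lambda a, b: a + b, x)
loss = out.sum()            # loss = 3*x0 + 2*x1 + 1*x2
loss.backward()

grad_ok = x.grad.clone()
print(out)                  # tensor([1., 3., 6.], grad_fn=...)
print(grad_ok)              # tensor([3., 2., 1.])

# Counter-example (hypothetical foo): detach() inside foo breaks the graph,
# so gradients no longer flow through the accumulated values.
def foo_detached(a, b):
    return (a + b).detach()

out_detached = scan(foo_detached, x)
(grad_detached,) = torch.autograd.grad(out_detached.sum(), x)
print(grad_detached)        # tensor([1., 0., 0.]): only x[0] is still connected
```

If the printed gradient matches the hand-computed one, append() and cat() have not interfered with backpropagation; a gradient that is zero where you expected a value points at something inside foo().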