I am trying to implement the sliding-window trick for BERT in plain PyTorch so that we can process long sequences. I am unable to implement it on batches without running loops. Basically, what is bothering me is how to split a long sequence into windows and then, after getting the embeddings, how to unpack them without running loops. If someone is familiar with the process, could you describe the trick?
More detail about the issue - https://github.com/google-research/bert/issues/66
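To make the question concrete, here is a rough sketch of one possible loop-free approach, not a definitive implementation. It assumes `Tensor.unfold` for cutting the windows and `index_add_` for putting the per-window embeddings back, with overlapping positions simply averaged (only one of several possible choices). The helper names `split_into_windows` and `merge_windows`, the window and stride values, and the stand-in used in place of the real BERT call are all hypothetical. It also assumes `stride <= window`, so that every position is covered by at least one window.

```python
import torch
import torch.nn.functional as F


def split_into_windows(input_ids, window, stride, pad_id=0):
    """(batch, seq_len) -> (batch * n_windows, window), plus n_windows."""
    batch, seq_len = input_ids.shape
    # Number of windows needed to cover the sequence (ceil division).
    n_windows = max(0, -(-(seq_len - window) // stride)) + 1
    padded_len = (n_windows - 1) * stride + window
    # Right-pad so the last window is full.
    input_ids = F.pad(input_ids, (0, padded_len - seq_len), value=pad_id)
    # unfold gives (batch, n_windows, window) without a Python loop.
    windows = input_ids.unfold(1, window, stride)
    return windows.reshape(batch * n_windows, window), n_windows


def merge_windows(window_emb, n_windows, stride, seq_len):
    """(batch * n_windows, window, hidden) -> (batch, seq_len, hidden)."""
    total, window, hidden = window_emb.shape
    batch = total // n_windows
    padded_len = (n_windows - 1) * stride + window
    # Position in the padded sequence of every token of every window.
    starts = torch.arange(n_windows, device=window_emb.device) * stride
    offsets = torch.arange(window, device=window_emb.device)
    positions = (starts[:, None] + offsets[None, :]).reshape(-1)
    flat = window_emb.reshape(batch, n_windows * window, hidden)
    # Sum embeddings that land on the same position, then divide by the
    # number of windows covering that position (i.e. average the overlaps).
    summed = flat.new_zeros(batch, padded_len, hidden).index_add_(1, positions, flat)
    counts = flat.new_zeros(padded_len).index_add_(
        0, positions, flat.new_ones((positions.numel(),))
    )
    merged = summed / counts[None, :, None]
    return merged[:, :seq_len]  # drop the right padding


# Hypothetical usage: a stand-in replaces the actual BERT forward pass.
input_ids = torch.randint(1, 1000, (2, 11))
windows, n_windows = split_into_windows(input_ids, window=4, stride=2)
fake_emb = torch.randn(windows.shape[0], windows.shape[1], 8)  # stand-in for BERT output
merged = merge_windows(fake_emb, n_windows, stride=2, seq_len=input_ids.shape[1])
```

With a real model, the attention mask would presumably have to be windowed the same way so that the padded tail of the last window is ignored.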
Any help would be appreciated.