Memory-friendly transformers over (too) long time series

I am doing time series classification and I want to experiment with transformers. Since a transformer's memory usage grows with the square of the input length, the whole sequence is much longer than the maximum input size the classifier can handle. So I am trying to come up with a way to encode disjoint sections of the sequence first, save those encodings into a tensor, and then feed that tensor to the classifier.

This is a very naive way to do it:

import torch

encoder = MyEncoder()        # encodes a sequence of length window_size into a scalar (for simplicity)
classifier = MyClassifier()  # classifies a sequence of length m >= n_windows

seq = torch.tensor([..my pretty long batch sequence ..])

window_size = ...  # some integer
n_windows = len(seq) // window_size

# encode each window separately, one forward pass per window
encoded_seq = torch.zeros(n_windows)
for i in range(n_windows):
    window = seq[i*window_size : (i+1)*window_size]
    encoded_seq[i] = encoder(window)

output = classifier(encoded_seq)

Is there a more efficient way to do this?

I thought of using tensor.apply_, but that seems slow too.
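
For reference, this is the kind of vectorized call I am hoping for instead of the Python loop. It is only a sketch and assumes MyEncoder could be written to accept a whole batch of windows of shape (n_windows, window_size) in a single forward pass:

# Sketch only: assumes the encoder can take a batch of windows at once
# and returns one scalar per window.
usable_len = n_windows * window_size                      # drop any trailing remainder
windows = seq[:usable_len].reshape(n_windows, window_size)
encoded_seq = encoder(windows)                            # shape (n_windows,)
output = classifier(encoded_seq)

That would replace the per-window loop with one batched forward pass, but I am not sure whether this is the idiomatic way to do it, or whether there is a better pattern for this encode-then-classify setup.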