Do I need a custom modified MultiheadAttention layer?

import torch
import torch.nn as nn


class Foo(nn.Module):
    def __init__(self, dim, n_iter):
        super().__init__()
        ...
        self.n_iter = n_iter
        self.attn = nn.MultiheadAttention(dim, 1)

    def forward(self, input):
        # (batch, dim) -> (1, batch, dim): a sequence of length 1
        input = input.unsqueeze(0)

        for i in range(self.n_iter):
            # The sequence grows by one element per iteration
            assert input.size(0) == i + 1

            # Full self-attention over the whole sequence built so far
            output, _ = self.attn(input, input, input, need_weights=False)

            # For every iteration I need only the last feature,
            # since it depends on all previous inputs (like an RNN)
            last = output[-1]

            # The new element (derived from the last output) is concatenated
            # onto the previous input, which means some results of the previous
            # `attn` call could be reused; some_func is assumed to return
            # a (1, batch, dim) tensor
            input = torch.cat([input, some_func(last)], dim=0)
        ...

Updated to make the question clearer.

I want to solve a problem iteratively, similar to RNNs but using self-attention.
I have a series of inputs: I want the first output to depend only on the first input, the second output to depend on both the first and second inputs, and so on.
The difference is that each new input depends on the previous output, which is why I do it iteratively.

I think MultiheadAttention is a good fit for this problem, but it performs a lot of unnecessary computation:

  1. For each iteration, I need only the last output.
  2. Later iterations can reuse some hidden state from earlier iterations (e.g. the key/value projections of the previous inputs).

Do I need a custom MultiheadAttention layer that computes only the last output and caches the hidden states so they can be reused later?
If so, how would I do that, given that forward delegates the work to F.multi_head_attention_forward?
Or is there another way that is better suited to my needs?
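
For concreteness, here is a minimal sketch of the kind of caching I have in mind: a hand-rolled single-head attention with a key/value cache, where each step projects only the newest element and attends with only the newest query. CachedSelfAttention and the stand-in feedback function below are my own placeholder names, not an existing PyTorch API.

import math
import torch
import torch.nn as nn


class CachedSelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x_new, cache=None):
        # x_new: (batch, dim) -- only the newest element of the sequence
        k_new = self.k_proj(x_new).unsqueeze(1)      # (batch, 1, dim)
        v_new = self.v_proj(x_new).unsqueeze(1)      # (batch, 1, dim)
        if cache is not None:
            k = torch.cat([cache[0], k_new], dim=1)  # (batch, t, dim)
            v = torch.cat([cache[1], v_new], dim=1)
        else:
            k, v = k_new, v_new

        # Only the newest query attends to all cached keys/values, so each
        # step costs O(t) instead of recomputing the full O(t^2) attention
        q = self.q_proj(x_new).unsqueeze(1)          # (batch, 1, dim)
        weights = torch.softmax(q @ k.transpose(1, 2) / math.sqrt(self.dim), dim=-1)
        out = self.out_proj(weights @ v).squeeze(1)  # (batch, dim)
        return out, (k, v)


# Illustrative loop; torch.tanh stands in for my real feedback function
attn = CachedSelfAttention(dim=16)
x = torch.randn(4, 16)
cache = None
for _ in range(5):
    out, cache = attn(x, cache)
    x = torch.tanh(out)

No causal mask is needed here because only the newest query attends, and it is allowed to see the whole sequence so far.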

I suspect anyone who tries to model a time series with self-attention runs into a similar situation (except that in my case the next input depends on the previous output).
But I'm new to this domain and don't know where to look for a solution.
Can anyone help?

Any help would be appreciated!

If I understood correctly, I think you can do that by simply writing something like this in the forward function:

def forward(self, x):
    previous = x
    x = self.layer_or_model(x)   # any wrapped layer or sub-model
    x = x + previous             # add the skip (residual) connection
    return x

and use it as a building block. I believe this kind of solution works because I have used it before, but I am not sure it is what you asked for.
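
For example, you could wrap any sub-module this way (a rough sketch; Residual is just a name I made up, not a built-in PyTorch module):

import torch
import torch.nn as nn


class Residual(nn.Module):
    def __init__(self, layer_or_model):
        super().__init__()
        self.layer_or_model = layer_or_model

    def forward(self, x):
        previous = x
        x = self.layer_or_model(x)
        return x + previous


block = Residual(nn.Linear(16, 16))    # skip connection around a linear layer
out = block(torch.randn(4, 16))        # output has the same shape as the input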

Sorry, I didn’t make it clear.

The way I want to solve the problem is similar to an RNN.
I have a series of inputs: I want the first output to depend only on the first input, the second output to depend on the first and second inputs, and so on.
The difference is that each new input depends on the previous output, which is why I do it iteratively.

Thanks for your help!