```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self, dim, n_iter):
        ...
        self.n_iter = n_iter
        self.attn = nn.MultiheadAttention(dim, 1)

    def forward(self, input):
        input = input.unsqueeze(0)
        for i in range(self.n_iter):
            assert input.size(0) == i + 1
            output, _ = self.attn(input, input, input, need_weights=False)
            # On each iteration I need only the last feature,
            # since it alone depends on all previous inputs (like RNNs).
            last = output[-1]
            # The new input is the old input with one row appended,
            # so some results of the previous `attn` call could be reused.
            input = torch.cat([input, some_func(last)], dim=0)
        ...
```
Updated to make the question clearer.
I want to solve a problem iteratively, similar to RNNs but using self-attention.
I have a series of inputs: the first output should depend only on the first input, the second output on the first and second inputs, and so on.
The difference is that the second input depends on the first output, which is why I have to run this iteratively.
I think `MultiheadAttention` is a good fit for this problem, but it performs a lot of unnecessary computation:
- On each iteration, I need only the last output.
- Later iterations could reuse some hidden state from earlier iterations (e.g. the q/k/v projections of the previous inputs).
Do I need a custom, modified `MultiheadAttention` layer that computes only the last output and saves the hidden states for later reuse?
If so, how would I do that, given that its `forward` delegates to `F.multi_head_attention_forward`, which is not easy to modify?
Or is there another way better suited to my needs?
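To make the reuse idea concrete, here is a rough sketch of what I imagine, written with plain scaled dot-product attention instead of `nn.MultiheadAttention` (all class and method names here are my own, not PyTorch API): the k/v projections of previous inputs are cached, and each step computes only the attention row for the newest input.

```python
import math
import torch
import torch.nn as nn

class IncrementalSelfAttention(nn.Module):
    """Single-head self-attention that caches K/V, so each step only
    projects the newest token and computes its attention row."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.reset()

    def reset(self):
        # Cached keys/values for all tokens seen so far: (seq_len, dim).
        self.k_cache = None
        self.v_cache = None

    def step(self, x):
        # x: (dim,) -- the newest token only.
        x = x.unsqueeze(0)           # (1, dim)
        q = self.q_proj(x)           # query for the newest position only
        k = self.k_proj(x)
        v = self.v_proj(x)
        if self.k_cache is None:
            self.k_cache, self.v_cache = k, v
        else:
            self.k_cache = torch.cat([self.k_cache, k], dim=0)
            self.v_cache = torch.cat([self.v_cache, v], dim=0)
        # Attention row for the last position: (1, seq_len).
        scores = q @ self.k_cache.t() / math.sqrt(self.dim)
        weights = scores.softmax(dim=-1)
        return (weights @ self.v_cache).squeeze(0)   # (dim,)
```

Each call to `step` does O(seq_len) work instead of recomputing the full O(seq_len²) attention, which is exactly the saving I am hoping for, but I don't know if this is the idiomatic way to get it.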
I imagine anyone applying self-attention to a time series autoregressively runs into a similar situation (except that here the next input depends on the previous output).
But I'm new to this domain and don't know where to look for a solution.
Any help would be appreciated!