# How to reuse precalculated attention weights for autoregressive transformers

Hi, when I run inference (not training) with my autoregressive transformer, I essentially do it this way (I removed a few lines for readability):

``````python
for i in range(max_batch_sequence_len):
    for layer in self.layers:
        y[:, i] = layer(x, keep_mask, y)[:, i]
``````

where my layers' `forward` is:

``````python
def forward(self, x: torch.Tensor, keep_mask: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Pre-norm decoder block: self-attention, cross-attention, FFN,
    # each with a residual connection
    y_normed = self.layer_norm(y)
    y = y + self.self_attn(y_normed)  # a causal mask is applied

    x_normed = self.layer_norm(x)
    y_normed = self.layer_norm(y)
    y = y + self.cross_attn(y_normed, x_normed, keep_mask)

    y_normed = self.layer_norm(y)
    y = y + self.ffn(y_normed)
    return y

def self_attn(self, y):
    out, _ = self.attn1(
        query=y, key=y, value=y, need_weights=False, is_causal=True,
    )
    return out

def cross_attn(self, y, x, attn_mask):
    out, _ = self.attn2(
        query=y, key=x, value=x, attn_mask=attn_mask, need_weights=False,
    )
    return out
``````

I can see that by saving the attention computations and feeding them back in somehow, I could reduce the work: at step i+1, everything the attention needs for positions j <= i has already been computed at earlier steps.
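(If I understand the redundancy right, what can actually be reused are the key/value projections for positions j <= i, not the softmax weights themselves, since each new query produces a fresh attention row. A minimal single-head sketch, with made-up dimensions and no batching, showing that the cached computation matches the full causal one:)

```python
import torch

torch.manual_seed(0)

d, T = 8, 5                               # illustrative head dim and sequence length
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
y = torch.randn(T, d)                     # hypothetical decoder inputs

def full_causal_attn(y):
    # Recomputes everything over the whole sequence (what the loop above does).
    q, k, v = y @ Wq, y @ Wk, y @ Wv
    scores = (q @ k.T) / d ** 0.5
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(-1) @ v

def cached_attn(y):
    # Incremental decoding: keys/values for j <= i are computed once and cached;
    # each step only projects the single new position.
    k_cache, v_cache, outs = [], [], []
    for i in range(y.shape[0]):
        q = y[i:i + 1] @ Wq
        k_cache.append(y[i:i + 1] @ Wk)   # only the NEW key/value each step
        v_cache.append(y[i:i + 1] @ Wv)
        k, v = torch.cat(k_cache), torch.cat(v_cache)
        scores = (q @ k.T) / d ** 0.5     # causal by construction: only j <= i exist
        outs.append(scores.softmax(-1) @ v)
    return torch.cat(outs)

assert torch.allclose(full_causal_attn(y), cached_attn(y), atol=1e-5)
```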

Has anyone here dealt with this, and can you suggest a modification of my code?
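Something like the following is roughly what I have in mind, a hypothetical sketch rather than my real classes: a single-head self-attention that takes only the newest position plus a running key/value cache (ignoring cross-attention and multi-head details):

```python
import torch
import torch.nn as nn

class CachedSelfAttn(nn.Module):
    # Hypothetical single-head self-attention with an external key/value cache.
    def __init__(self, d):
        super().__init__()
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.d = d

    def forward(self, y_new, cache):
        # y_new: (batch, 1, d) -- only the newest position.
        # Append this step's key/value; earlier ones are never recomputed.
        cache["k"] = torch.cat([cache["k"], self.k(y_new)], dim=1)
        cache["v"] = torch.cat([cache["v"], self.v(y_new)], dim=1)
        q = self.q(y_new)
        scores = q @ cache["k"].transpose(1, 2) / self.d ** 0.5
        # No causal mask needed: the cache only holds positions <= current.
        return scores.softmax(-1) @ cache["v"]

# Usage: step through a sequence one token at a time.
d, B, T = 8, 2, 5
attn = CachedSelfAttn(d)
cache = {"k": torch.zeros(B, 0, d), "v": torch.zeros(B, 0, d)}
y = torch.randn(B, T, d)
out = torch.cat([attn(y[:, i:i + 1], cache) for i in range(T)], dim=1)  # (B, T, d)
```

The generation loop would then feed only `y[:, i:i+1]` through the stack, so each step does work proportional to the prefix length instead of re-running attention for every position.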

Hi, I’m not sure if this will help solve your problem, but just to give some ideas: when I’ve used transformers, I’d typically do something like this for inference (we wouldn’t want to instantiate the model every time, of course):

``````python
model = MyTransformer(**args)
model = torch.compile(model)