How to reuse precalculated attention weights for autoregressive transformers

Hi, when I run inference (not training) with my autoregressive transformer, I essentially do it this way (I removed a few lines for readability):

for i in range(max_batch_sequence_len):
    for layer in self.layers:
        y[:, i] = layer(x, keep_mask, y)[:, i]

where my layers' forward method is:

def forward(self, x: torch.Tensor, keep_mask: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Build a (batch, tgt_len, src_len) padding mask for the cross-attention,
    # then expand it to one copy per head.
    attn_mask = (~keep_mask).unsqueeze(1) & (~keep_mask).unsqueeze(2)
    attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)

    # Pre-norm self-attention (a causal mask is applied inside self_attn)
    y_normed = self.layer_norm(y)
    y = y + self.self_attn(y_normed)

    # Pre-norm cross-attention over the encoder output x
    x_normed = self.layer_norm(x)
    y_normed = self.layer_norm(y)
    y = y + self.cross_attn(y_normed, x_normed, attn_mask)

    # Feed-forward block
    y_normed = self.layer_norm(y)
    y = y + self.ffn(y_normed)
    return y

def self_attn(self, y):
    # Causal self-attention over the decoder states
    out, _ = self.attn1(
        query=y, key=y, value=y, need_weights=False, is_causal=True,
    )
    return out

def cross_attn(self, y, x, attn_mask):
    # Cross-attention: decoder states attend over the encoder output x
    out, _ = self.attn2(
        query=y, key=x, value=x, need_weights=False, attn_mask=attn_mask
    )
    return out

I can see that by reusing the attention weights and re-inputting them in the right way I should be able to reduce the computation, since at step i+1 the attention weights for all positions j <= i have already been computed at earlier steps.
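For concreteness, here is a rough single-head sketch of the kind of caching I have in mind (the CachedSelfAttention module, the cache dict and the shapes below are only illustrative, not my actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedSelfAttention(nn.Module):
    # Sketch: project q/k/v for the newest position only, append the new
    # key/value to a cache, and attend the new query over all cached positions.
    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, y_new: torch.Tensor, cache: dict) -> torch.Tensor:
        # y_new: [batch, 1, d_model] -- only the token produced at this step
        q = self.q_proj(y_new)
        k = self.k_proj(y_new)
        v = self.v_proj(y_new)

        # Reuse the keys/values computed at steps j <= i instead of recomputing them
        if "k" in cache:
            cache["k"] = torch.cat([cache["k"], k], dim=1)
            cache["v"] = torch.cat([cache["v"], v], dim=1)
        else:
            cache["k"], cache["v"] = k, v

        # No causal mask is needed: the cache only holds positions <= i
        out = F.scaled_dot_product_attention(q, cache["k"], cache["v"])
        return self.out_proj(out)

# Usage: decode one position at a time, carrying the cache along
attn = CachedSelfAttention(d_model=64)
cache = {}
tokens = torch.randn(2, 10, 64)            # stand-in for embedded tokens
for i in range(tokens.size(1)):
    out_i = attn(tokens[:, i:i+1], cache)  # only the new position is computed

For the cross-attention the situation is even simpler, since x never changes during decoding: its key/value projections could be computed once before the loop and reused at every step.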

Has anyone here dealt with this before and can suggest a modification to my code?

Hi, I'm not sure if this will help solve your problem, but just to give some ideas: typically when I've used transformers I'd do something like this for inference (we wouldn't want to instantiate the model every time, of course):

model = MyTransformer(**args)
model.load_state_dict(state_dict)   # load the trained weights before compiling
model = torch.compile(model)
model.eval()

with torch.no_grad():
    predictions = model(src, tgt)
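This doesn't reuse anything between decoding steps, but torch.compile can fuse and specialize the attention kernels, so it may already buy back some speed; the key/value caching you describe would be complementary to it.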