Hi, when I run inference (not training) with my autoregressive transformer, I essentially do it this way (I removed a few lines for readability):
for i in range(max_batch_sequence_len):
    for layer in self.layers:
        y[:, i] = layer(x, keep_mask, y)[:, i]
where my layers' forward is:
def forward(self, x: torch.Tensor, keep_mask: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    attn_mask = (~keep_mask).unsqueeze(1) & (~keep_mask).unsqueeze(2)
    attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)
    y_normed = self.layer_norm(y)
    y = y + self.self_attn(y_normed)  # a causal mask is applied
    x_normed = self.layer_norm(x)
    y_normed = self.layer_norm(y)
    y = y + self.cross_attn(y_normed, x_normed, attn_mask)
    y_normed = self.layer_norm(y)
    y = y + self.ffn(y_normed)
    return y
def self_attn(self, y):
    out, _ = self.attn1(
        query=y, key=y, value=y, need_weights=False, is_causal=True,
    )
    return out

def cross_attn(self, y, x, attn_mask):
    out, _ = self.attn2(
        query=y, key=x, value=x, need_weights=False, attn_mask=attn_mask
    )
    return out
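For reference, here is a quick standalone check of the mask shapes I believe nn.MultiheadAttention expects for a boolean attn_mask with batch_first=True (True meaning "do not attend"); the batch size, head count, and keep_mask values below are made up for illustration, and this only checks shapes, not whether the mask semantics are what I want:

```python
import torch

batch, num_heads, seq_len = 2, 4, 5
# True for real tokens, False for padding (made-up example values)
keep_mask = torch.ones(batch, seq_len, dtype=torch.bool)
keep_mask[0, 3:] = False  # pretend the last two positions of sample 0 are padding

pad = ~keep_mask                                  # True where padded
attn_mask = pad.unsqueeze(1) & pad.unsqueeze(2)   # (batch, seq, seq)
attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)

# nn.MultiheadAttention accepts a 3-D mask of shape (batch * num_heads, L, S)
print(attn_mask.shape)
```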
I can see that by caching the attention keys/values from earlier steps and feeding them back in, I could cut down the computation: at step i+1, everything needed for positions j <= i has already been computed at the previous steps.
Has anyone here dealt with this before and can suggest a modification to my code?
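To make the idea concrete, here is a minimal standalone sketch of the kind of key/value caching I have in mind for the self-attention part (this is a simplified stand-in, not my actual layer: the class name, dimensions, and projections are placeholders, and it feeds in one new token per step):

```python
import torch
import torch.nn.functional as F

class CachedSelfAttention(torch.nn.Module):
    """Hypothetical sketch: self-attention with an incremental KV cache."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.k_cache = None  # (batch, heads, steps_so_far, head_dim)
        self.v_cache = None

    def _split(self, t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        b, s, _ = t.shape
        return t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, y_step: torch.Tensor) -> torch.Tensor:
        # y_step: (batch, 1, embed) -- only the newest position
        q = self._split(self.q_proj(y_step))
        k = self._split(self.k_proj(y_step))
        v = self._split(self.v_proj(y_step))
        if self.k_cache is not None:
            # reuse keys/values computed at earlier steps
            k = torch.cat([self.k_cache, k], dim=2)
            v = torch.cat([self.v_cache, v], dim=2)
        self.k_cache, self.v_cache = k, v
        # no causal mask needed: the single query only sees positions <= i
        out = F.scaled_dot_product_attention(q, k, v)
        b = q.shape[0]
        out = out.transpose(1, 2).reshape(b, 1, -1)
        return self.out_proj(out)
```

Feeding tokens one at a time through this should match a full forward pass with is_causal=True, since at step i the query attends exactly to the cached positions 0..i.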