[Torch Dynamo] too many subgraphs for LLM model inference due to mask


I plan to use torch.compile on models like Llama to do some optimizations, for example replacing the naive attention implementation with fused attention.
However, dynamo generates hundreds of subgraphs, many of them caused by the attention mask being updated at each token-generation step.

Is there any way to work around this?


You can try setting dynamic=True. Alternatively, I've seen some OSS implementations preallocate the mask so its shape stays static.
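A minimal sketch of both suggestions, assuming a simple scaled-dot-product attention function (the `MAX_SEQ_LEN` constant and `attention_scores` helper are illustrative, not from any particular Llama implementation): the causal mask is allocated once at its maximum size and sliced per step, and `dynamic=True` asks dynamo to trace with symbolic shapes so the growing sequence length does not trigger a recompile per token.

```python
import torch

MAX_SEQ_LEN = 512

# Preallocate the full causal mask once, outside the compiled region,
# so the mask buffer itself never changes between decoding steps.
causal_mask = torch.triu(
    torch.full((MAX_SEQ_LEN, MAX_SEQ_LEN), float("-inf")), diagonal=1
)

def attention_scores(q, k, seq_len):
    # Slice the static buffer instead of rebuilding the mask each step.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return scores + causal_mask[:seq_len, :seq_len]

# dynamic=True lets dynamo treat seq_len as a symbolic dimension
# rather than specializing (and recompiling) for every length.
compiled_scores = torch.compile(attention_scores, dynamic=True)
```

This is only a sketch of the masking pattern; in a real model the preallocated mask would live as a registered buffer on the attention module.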