This is about optimizing the cuDNN SDPA backend so that it does not rebuild its graph whenever the batch size or input dimensions change.

According to the PyTorch code at this link, when SDPA is run with the cuDNN backend multiple times and the batch size (B) or sequence length (S) of the query, key, or value changes, the cuDNN graph has to be rebuilt for each new shape. This is quite time-consuming. Could you advise whether there is any way to avoid rebuilding the cuDNN graph repeatedly?
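To make the cost pattern concrete, here is a toy model in plain Python (no GPU needed) of a graph cache keyed on tensor shape. It assumes, for illustration only, that a graph is specialized on (B, H, S_q, S_kv, D), so every new batch size or sequence length is a cache miss and therefore a rebuild. GraphCache and bucket are hypothetical names and do not reflect PyTorch's actual implementation. The second loop sketches one generic mitigation, padding sequence lengths up to a small fixed set of buckets so that only a few distinct graphs are ever built. Whether that suits this case is an open question, and in real use the padded positions would have to be masked out.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple

# Hypothetical cache key: the shape fields a graph is assumed to be specialized on.
ShapeKey = Tuple[int, int, int, int, int]  # (B, H, S_q, S_kv, D)


@dataclass
class GraphCache:
    """Toy shape-keyed graph cache (illustrative, not PyTorch's implementation)."""
    graphs: Dict[ShapeKey, str] = field(default_factory=dict)
    builds: int = 0

    def get(self, key: ShapeKey) -> str:
        # A miss stands in for the expensive graph build; a hit reuses the graph.
        if key not in self.graphs:
            self.builds += 1
            self.graphs[key] = f"graph{key}"
        return self.graphs[key]


def bucket(n: int, buckets: Tuple[int, ...] = (128, 256, 512, 1024)) -> int:
    """Round a sequence length up to the nearest bucket (hypothetical padding policy)."""
    for b in buckets:
        if n <= b:
            return b
    return n  # longer than every bucket: leave unpadded


# Calls with changing sequence lengths and one change in batch size.
calls = [(1, 8, s, s, 64) for s in (100, 120, 130, 200, 250, 300, 500)]
calls.append((2, 8, 100, 100, 64))

# Exact shapes: every distinct shape triggers a build.
exact = GraphCache()
for key in calls:
    exact.get(key)

# Bucketed shapes: many lengths collapse onto the same padded shape.
padded = GraphCache()
for b, h, sq, skv, d in calls:
    padded.get((b, h, bucket(sq), bucket(skv), d))

print(exact.builds, padded.builds)  # 8 4
```

With exact shapes all eight calls build a graph, while bucketing reduces this to four (sequence lengths 128, 256 and 512 at B=1, plus 128 at B=2). The trade-off is wasted compute on padding, so bucket boundaries would need tuning against the real length distribution.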