I am currently using an RTX 4060 on my desktop and a Jetson AGX Orin board for research.
When I run inference with the Llama 3.2 1B model, whether or not I use the KV cache (past_key_values) changes the attention implementation (backend) that gets selected.
On the RTX 4060, using the KV cache reduced inference time, but on the Jetson board it increased inference time.
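For context, the comparison is essentially the following (a simplified sketch, not my exact script; the model ID, prompt, and token counts are placeholders):

```python
# Simplified sketch of the with-cache vs. without-cache timing comparison.
# Assumes access to meta-llama/Llama-3.2-1B; prompt and max_new_tokens are placeholders.
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")

for use_cache in (True, False):
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=64, do_sample=False, use_cache=use_cache)
    torch.cuda.synchronize()
    print(f"use_cache={use_cache}: {time.time() - start:.3f}s")
```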
According to the PyTorch profiler, the attention backend chosen on the Jetson when using the KV cache is scaled_dot_product_cudnn_attention, and it takes an extremely long time on the first iteration (over 100x longer than the no-cache version).
Does anyone know why this happens and how I can fix it?
The image below is a snapshot of the profiler output.
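The profile was captured roughly like this (simplified sketch, reusing the model and inputs from the snippet above):

```python
# Simplified sketch of the profiling run (placeholder generation call).
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=64, do_sample=False, use_cache=True)

# On the Jetson, scaled_dot_product_cudnn_attention shows up in this table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```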
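As a possible workaround I could restrict the SDPA backends so the cuDNN kernel is never picked (a sketch assuming PyTorch 2.3+ where torch.nn.attention is available, reusing the model and inputs from the first snippet), but I would still like to understand why cuDNN attention is selected and why its first call is so slow:

```python
# Workaround sketch: restrict SDPA to non-cuDNN backends for the cached run.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                  SDPBackend.EFFICIENT_ATTENTION,
                  SDPBackend.MATH]):
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=64, do_sample=False, use_cache=True)

# On PyTorch 2.4+ the cuDNN SDPA path can also be disabled globally:
# torch.backends.cuda.enable_cudnn_sdp(False)
```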