Is PyTorch’s memory-efficient attention implementation of scaled_dot_product_attention the same as xFormers’ memory_efficient_attention (which uses Flash-Decoding, according to "Flash-Decoding for long-context inference | PyTorch")? I’m interested in testing Flash-Decoding, and PyTorch’s scaled_dot_product_attention page ("torch.nn.functional.scaled_dot_product_attention — PyTorch 2.2 documentation") links to xFormers’ GitHub.
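For context, the idea behind Flash-Decoding, as the PyTorch blog post describes it, is to split the keys and values along the sequence dimension, compute attention for each chunk in parallel, keep one extra scalar per chunk (the log-sum-exp of its scores), and then combine the partial outputs in a final reduction. The code below is a minimal pure-Python sketch of that split-KV math for a single decoding query. It is only an illustration and is not the xFormers or PyTorch kernel. The function names attention and split_kv_attention are hypothetical helpers introduced here.

```python
import math
import random


def attention(q, keys, values):
    """Reference softmax attention for one query vector (hypothetical helper)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def split_kv_attention(q, keys, values, num_splits):
    """Split-KV attention in the style of Flash-Decoding (hypothetical helper).

    Each KV chunk is processed independently (on a GPU these would run in
    parallel); each chunk returns its normalized output plus its log-sum-exp.
    """
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    chunk = math.ceil(len(keys) / num_splits)
    partials = []  # list of (chunk output, chunk log-sum-exp)
    for start in range(0, len(keys), chunk):
        ks, vs = keys[start:start + chunk], values[start:start + chunk]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in ks]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out = [sum(w * v[d] for w, v in zip(weights, vs)) / total for d in range(dim)]
        partials.append((out, m + math.log(total)))

    # Reduction step: rescale each chunk output by exp(lse_i - global_lse).
    lse_max = max(lse for _, lse in partials)
    global_lse = lse_max + math.log(sum(math.exp(lse - lse_max) for _, lse in partials))
    result = [0.0] * dim
    for out, lse in partials:
        w = math.exp(lse - global_lse)
        for d in range(dim):
            result[d] += w * out[d]
    return result


# Usage: one query against a 100-token KV cache, head dimension 8.
random.seed(0)
q = [random.gauss(0, 1) for _ in range(8)]
keys = [[random.gauss(0, 1) for _ in range(8)] for _ in range(100)]
values = [[random.gauss(0, 1) for _ in range(8)] for _ in range(100)]
ref = attention(q, keys, values)
out = split_kv_attention(q, keys, values, num_splits=4)
print(max(abs(a - b) for a, b in zip(ref, out)))  # should be ~1e-16
```

Because each chunk only needs its own log-sum-exp to be rescaled correctly, the chunks can be computed independently, which is what lets Flash-Decoding use many GPU blocks even when the batch size and query length are small.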