A couple of questions.
- Which kernels does the following line refer to?
  TORCH_WARN("Both fused kernels do not support non-zero dropout.");
  from pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h at 51f91e3428be9201e9c1199ec704bc132f979955 in pytorch/pytorch
- Is there a way to see which attention kernel torch picked for a given run?