_scaled_mm is defined in pytorch/aten/src/ATen/native/native_functions.yaml (pytorch/pytorch, release/2.5), and its out_dtype can differ from the input dtype, so I'm curious how dtypes are handled in its backward function. I checked pytorch/tools/autograd/derivatives.yaml (pytorch/pytorch, release/2.5) and did not see any entry for scaled_mm.
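
To make the dtype concern concrete, here is the math I have in mind. This is only a sketch under my own simplifying assumptions, not something taken from the PyTorch source: I treat the scales as scalars s_a and s_b, write G for the gradient with respect to the output C, and ignore the effect of the output cast (cast_out below is just my name for the conversion to out_dtype).

```latex
% Forward (assumed, simplified): scaled matmul, then cast to out_dtype
C = \operatorname{cast}_{\text{out}}\bigl( s_a \, s_b \, A B \bigr)

% Backward (ordinary mm gradients, cast treated as identity), with G = \partial L / \partial C
\frac{\partial L}{\partial A} = s_a \, s_b \, G B^{\top},
\qquad
\frac{\partial L}{\partial B} = s_a \, s_b \, A^{\top} G
```

Both gradients are again matrix multiplications, but G arrives in out_dtype while A and B are in the input dtype, so some conversion would have to happen somewhere. That is the part I cannot find in the code.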
So my question is: how does autograd know what the backward function for _scaled_mm is? (I know the backward is also an mm, but how are the two connected in code?) Thanks.