The backward for _scaled_mm

_scaled_mm is defined in pytorch/aten/src/ATen/native/native_functions.yaml at release/2.5 · pytorch/pytorch · GitHub, and its out_dtype can differ from the input dtype, so I'm curious how dtypes are handled in its backward. I checked pytorch/tools/autograd/derivatives.yaml at release/2.5 · pytorch/pytorch · GitHub and do not see any entry for _scaled_mm.

So my question is: how does autograd know what the backward function for _scaled_mm is? (I know the backward is also an mm, but how are they connected in the code?) Thanks.

_scaled_mm is the raw operation, while the Autograd usage depends on the derived dtypes as described here.
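To illustrate the point above: since _scaled_mm has no entry in derivatives.yaml, it is not differentiable on its own; fp8 training recipes wrap it in a custom torch.autograd.Function whose backward is an ordinary pair of matmuls in the gradient dtype. Below is a minimal sketch of that pattern. The class name ScaledMM is hypothetical, and plain torch.mm stands in for the fp8 call so the sketch runs on any device; real fp8 code would cast the inputs and call torch._scaled_mm in forward.

```python
import torch

class ScaledMM(torch.autograd.Function):
    """Hypothetical wrapper sketching how an fp8 matmul can get a backward.

    A real implementation would cast a/b to float8 with scales and call
    torch._scaled_mm(...) in forward; plain torch.mm is used here so the
    example runs without fp8-capable hardware.
    """

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return torch.mm(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # The backward of a matmul is itself two matmuls:
        #   dA = dC @ B^T,  dB = A^T @ dC
        # computed in the gradient's dtype, not in fp8 -- which is why no
        # derivatives.yaml entry for the raw _scaled_mm op is needed.
        return grad_out.mm(b.t()), a.t().mm(grad_out)

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, 5, requires_grad=True)
ScaledMM.apply(a, b).sum().backward()
```

After `backward()`, `a.grad` and `b.grad` match what autograd would produce for a plain `torch.mm`, confirming the custom backward is wired in through the Function subclass rather than through derivatives.yaml.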

Thanks, that's a nice introduction to fp8.