Hi there, I have a model that has no accuracy issues under torch.compile() with torch.set_float32_matmul_precision("highest"), but does have accuracy issues with torch.set_float32_matmul_precision("high") or lower. As I understand it, there is no guarantee that compiled numerics will match eager, so I’m looking for a workaround. Does torch.compile have a built-in way, such as a decorator, to disable TF32 for only some of the modules or functions in my nn.Module, so that I can close the accuracy gap?