What options do I have when torch.compile with TF32 matmuls causes accuracy issues?

Hi there, I have a model that shows no accuracy issues under torch.compile() when torch.set_float32_matmul_precision("highest") is set, but does show accuracy issues with "high" (or lower). As I understand it, there is no guarantee that compiled numerics will match eager, so I'm looking for a workaround. Does torch.compile have any built-in way or decorator that lets me disable/turn off TF32 for some of the modules or functions in the nn.Module, so I can close the accuracy gap?
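
For reference, the kind of thing I had in mind is a small context manager like the sketch below (the fp32_matmul_precision helper and MyModel are just things I wrote for illustration, not torch APIs). I'm not sure whether torch.compile respects the precision setting at runtime or bakes it in when the graph is compiled, which is why I'm asking if there's an official way:

```python
import torch
import torch.nn as nn
from contextlib import contextmanager


# Hypothetical workaround: temporarily switch the float32 matmul precision
# back to "highest" around the accuracy-sensitive sub-module, then restore
# whatever was set before.
@contextmanager
def fp32_matmul_precision(mode: str):
    prev = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision(mode)
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(prev)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sensitive = nn.Linear(1024, 1024)  # needs full-precision fp32 matmuls
        self.rest = nn.Linear(1024, 1024)       # fine with TF32

    def forward(self, x):
        # Force "highest" precision only for the sensitive part.
        with fp32_matmul_precision("highest"):
            x = self.sensitive(x)
        return self.rest(x)


torch.set_float32_matmul_precision("high")  # TF32 allowed globally
model = torch.compile(MyModel().cuda())
out = model(torch.randn(8, 1024, device="cuda"))
```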