FP32 with TF32 precision

I’m using PyTorch with a V100 GPU. As this GPU doesn’t support TF32 operations, I’m adjusting my x (input to the prediction model) and y (ground truth) tensors, which are in FP32, to roughly match TF32’s 10-bit mantissa by rounding, e.g. x = torch.round(x, decimals=4) (I chose 4 decimal places following this article: FP64, FP32, FP16, BFLOAT16, TF32, and other members of the ZOO | by Grigory Sapunov | Medium, in the TF32 section). Would this rounding be enough to make the FP32 values very close to their TF32 counterparts? And, assuming my approach is correct, should I also reduce the model precision with model.half()? I’m making these adjustments because, for some reason, my model converges well on an Ampere GPU (RTX A4000) but not on the Volta GPU (V100), and my guess is that this is because TF32 is no longer used in the operations.
Thanks in advance.

No, direct rounding won’t match your A4000, since e.g. accumulations are still performed in FP32 as described here. Also, only convolutions use TF32 by default, while matmuls can use TF32 in newer PyTorch releases if you allow it via torch.backends.cuda.matmul.allow_tf32 = True. Native PyTorch layers will not use TF32.
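To see why decimal rounding differs from TF32’s representation: TF32 keeps 10 of FP32’s 23 mantissa bits, so a closer (though still incomplete) emulation of the *storage* format is to zero the low 13 mantissa bits. A minimal pure-Python sketch (the helper name is hypothetical; real TF32 rounds to nearest rather than truncating, and accumulation in FP32 would still differ):

```python
import struct

def truncate_to_tf32(value: float) -> float:
    """Zero the low 13 of FP32's 23 mantissa bits, keeping the
    10 mantissa bits that TF32 stores. Note: real TF32 hardware
    rounds to nearest; this simple sketch truncates instead."""
    bits = struct.unpack('<I', struct.pack('<f', value))[0]
    bits &= ~((1 << 13) - 1)  # clear the 13 low mantissa bits
    return struct.unpack('<f', struct.pack('<I', bits))[0]

print(truncate_to_tf32(0.1))  # slightly below 0.1: low mantissa bits dropped
print(truncate_to_tf32(1.5))  # unchanged: 1.5's mantissa fits in 10 bits
```

Even a bit-exact version of this would only mimic how inputs are stored, not how TF32 tensor cores accumulate, which is the point above.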

Ok, thanks! So what do you suggest I do with my data and, if necessary, my code, so that execution on the V100 behaves as closely as possible to running on an Ampere GPU, since I suspect the divergence is due to the V100 not using TF32?

I would probably run the opposite experiment and disable TF32 on your Ampere device. If it’s still converging, another issue is likely causing the divergence on your V100.
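Concretely, disabling TF32 globally on the Ampere GPU is a two-line change with the existing backend flags, which makes matmuls and cuDNN convolutions run in plain FP32 as they would on a V100:

```python
import torch

# Disable TF32 so matmuls and cuDNN convolutions run in plain FP32,
# matching the default behavior on a pre-Ampere GPU such as the V100.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
```

Run the otherwise unchanged training script with these flags set and compare convergence against the TF32-enabled run.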

Ok, but at the moment I no longer have access to an Ampere GPU. What would you suggest regarding my last question?

I haven’t seen a model fail to converge in a larger dtype, so I think the likelihood of your training depending on TF32 for convergence is low. Also, I would not try to “simulate” TF32 behavior on your V100 without a proper reference, as you might also need to write custom kernels.

Ok. Well, maybe my data is special: it has many significant decimal places, so precision can play a role. I believe that training purely in FP32 forces the model to work harder to learn, because the precision is higher.

Additionally, what I see is that the L1 of my model ends up the same, but the main loss and other metrics do not converge well, which ends up harming the overall results.

See below a part of my code. One detail: among the losses I compute, the FFT loss takes on high values even on the Ampere GPU, making it, if not the main factor, certainly one of the main factors harming the total loss. Also, I’m computing the FFT through the JAX API. Before switching to it I couldn’t reach the desired L1, probably because torch’s FFT doesn’t let me use AMP in that section (at least in my case, where the shape of my data is not a power of 2), so that part ran in FP32 and hurt the result; with JAX the L1 matches what I get on the Ampere GPU. But whether I use JAX or not, the FFT metrics (fft_l1, fft_l1_rel and fft_l1_max) remain high. I also needed AMP to improve my L1.

import os
import jax
import numpy as np
import torch

scaler = torch.cuda.amp.GradScaler()

for seq in seq_train:
    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
        X, y = seq
        X = X.cuda()
        y = y.cuda()
        output = net(X)
        #ffty = torch.fft.rfftn(y.float(), dim=(-1, -2, -3))
        #fftout = torch.fft.rfftn(output.float(), dim=(-1, -2, -3))
        # FFT via JAX; note that .detach() cuts the autograd graph,
        # so no gradients flow back through the FFT loss terms
        ffty = jax.numpy.fft.fftn(y.detach().cpu().numpy(), axes=(-1, -2, -3))
        ffty = torch.from_numpy(np.array(ffty))
        fftout = jax.numpy.fft.fftn(output.detach().cpu().numpy(), axes=(-1, -2, -3))
        fftout = torch.from_numpy(np.array(fftout))
        fft_res = fftout - ffty
        fft_res_abs = torch.abs(fft_res)
        loss_fft_l1 = torch.mean(fft_res_abs)
        loss_fft_l1_rel = torch.mean(fft_res_abs / (torch.abs(ffty) + 0.01))
        fft_l1_max = torch.max(fft_res_abs)

        res = output - y
        res_abs = torch.abs(res)
        loss_l1 = torch.mean(res_abs)
        loss_l1_rel = torch.mean(res_abs / (torch.abs(y) + 0.01))
        l1_max = torch.max(res_abs)
        loss = loss_l1_rel + loss_fft_l1_rel
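As an aside on the snippet above: the JAX detour via .detach().cpu().numpy() cuts the autograd graph, so no gradients flow through the FFT part of the loss. One way to keep gradients while still forcing FP32 there is to stay in torch.fft and disable autocast locally. A sketch, using a hypothetical helper (not from the original code):

```python
import torch

def fft_l1_terms(output: torch.Tensor, y: torch.Tensor, eps: float = 0.01):
    # Hypothetical helper: compute the FFT residual terms with torch.fft
    # so gradients still flow through `output`. Disabling autocast
    # locally forces this section to run in FP32 even when the rest of
    # the training step uses AMP.
    with torch.autocast(device_type=output.device.type, enabled=False):
        ffty = torch.fft.fftn(y.float(), dim=(-1, -2, -3))
        fftout = torch.fft.fftn(output.float(), dim=(-1, -2, -3))
        res_abs = torch.abs(fftout - ffty)
        return (torch.mean(res_abs),                            # fft_l1
                torch.mean(res_abs / (torch.abs(ffty) + eps)),  # fft_l1_rel
                torch.max(res_abs))                             # fft_l1_max
```

Unlike the JAX version, the terms returned here are differentiable, so including loss_fft_l1_rel in the total loss actually trains against the FFT residual.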

Well, I hope that with this new information you can give me a new tip or suggestion. Thank you very much.

Hi @ptrblck, I don’t know if by chance you missed my last reply, but if so, please help in any way you can; I’d appreciate it! Thank you very much in advance!