Using Jax code with PyTorch code

Hello everybody! This neural network training snippet uses the fftn function from the JAX library so that I can train with mixed precision: I cannot use PyTorch’s fftn under mixed precision because of its restriction that the input size must be a power of 2, which is not the case for my data. The same structure is used for the validation and test parts.

My question: since I don’t use PyTorch’s fftn, is there a problem in my code regarding the gradient calculation? If so, is it serious enough to affect my network’s performance?

Thanks in advance!

for idx_epoch, epoch in enumerate(t):
        print()
        list_train_loss = []
        list_train_loss_l1 = []
        list_train_loss_l1_rel = []
        list_train_l1_max = []
        list_train_loss_fft_l1 = []
        list_train_loss_fft_l1_rel = []
        list_train_fft_l1_max = []
        
        list_validation_loss = []
        list_validation_loss_l1 = []
        list_validation_loss_l1_rel = []
        list_validation_l1_max = []
        list_validation_loss_fft_l1 = []
        list_validation_loss_fft_l1_rel = []
        list_validation_fft_l1_max = []
        
        list_test_loss = []
        list_test_loss_l1 = []
        list_test_loss_l1_rel = []
        list_test_l1_max = []
        list_test_loss_fft_l1 = []
        list_test_loss_fft_l1_rel = []
        list_test_fft_l1_max = []
    
        #l1_distribution_validation = []
        #l1_distribution_test = []
    
        #model.train()
        net.train()
               
        scaler = torch.cuda.amp.GradScaler()  
        import jax
        #os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
        os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.50'
        import time	
        for seq in seq_train:
                
            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):

                X, y = seq
                X = X.cuda()
                y = y.cuda()

                optimizer.zero_grad()
                #output = model(X)
                output = net(X)
           
                #ffty = torch.fft.rfftn(y.float(), dim=(-1, -2, -3))              
                #fftout = torch.fft.rfftn(output.float(), dim=(-1, -2, -3))                
                #ffty = torch.fft.rfftn(y, dim=(-1, -2, -3))
                #fftout = torch.fft.rfftn(output, dim=(-1, -2, -3))

                ffty = jax.numpy.fft.fftn(y.detach().cpu().numpy(), axes=(-1,-2,-3))
                ffty = torch.from_numpy(np.array(ffty))
                fftout = jax.numpy.fft.fftn(output.detach().cpu().numpy(), axes=(-1,-2,-3))
                fftout = torch.from_numpy(np.array(fftout))
                #ffty = jax.numpy.fft.rfft(y.detach().cpu().numpy(), 1)
                #ffty = torch.from_numpy(np.array(ffty))
                #fftout = jax.numpy.fft.rfft(output.detach().cpu().numpy(), 1)                
                #fftout = torch.from_numpy(np.array(fftout))

                #ffty = torch.fft.rfft(y, 1)
                #fftout = torch.fft.rfft(output, 1)
                
                fft_res = (fftout - ffty)
                fft_res_abs = torch.abs(fft_res)
                loss_fft_l1 = torch.mean(fft_res_abs)
                loss_fft_l1_rel = torch.mean(fft_res_abs/(torch.abs(ffty) + 0.01))
                fft_l1_max = torch.max(fft_res_abs)
    
                res = (output - y)
                res_abs = torch.abs(res)
                loss_l1 = torch.mean(res_abs)
                loss_l1_rel = torch.mean(res_abs/(torch.abs(y)+0.01))
                l1_max = torch.max(res_abs)
    
                loss = loss_l1_rel + loss_fft_l1_rel
            
            #loss.backward()
            scaler.scale(loss).backward()
            #optimizer.step()
            scaler.step(optimizer)
            scaler.update() 
        
            list_train_loss.append(loss.cpu().item())
            list_train_loss_l1.append(loss_l1.cpu().item())
            list_train_loss_l1_rel.append(loss_l1_rel.cpu().item())
            list_train_l1_max.append(l1_max.cpu().item())
            list_train_loss_fft_l1.append(loss_fft_l1.cpu().item())
            list_train_loss_fft_l1_rel.append(loss_fft_l1_rel.cpu().item())
            list_train_fft_l1_max.append(fft_l1_max.cpu().item())

Hi @marco_c,

Since you’re detaching your tensors (via the .detach() function) in order to pass the data to JAX, you’re effectively destroying the gradient calculation for that part of the loss.

You could try padding your input so that its size becomes a power of 2, and then just remove the padded results after the FFT call? A sketch of that idea is below.
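Untested sketch of the padding idea (next_pow2 and fftn_pow2_padded are just names I made up here), staying entirely in PyTorch so autograd can track the FFT:

import torch
import torch.nn.functional as F

def next_pow2(n: int) -> int:
    # smallest power of two that is >= n
    return 1 << (n - 1).bit_length()

def fftn_pow2_padded(x: torch.Tensor) -> torch.Tensor:
    # zero-pad the last three dims up to a power of two, then FFT;
    # everything stays inside PyTorch, so gradients can flow through it
    pad = []
    for d in (-1, -2, -3):  # F.pad expects (left, right) pairs starting from the last dim
        pad += [0, next_pow2(x.shape[d]) - x.shape[d]]
    return torch.fft.fftn(F.pad(x, pad), dim=(-1, -2, -3))

Note that the spectrum of the zero-padded signal is not identical to the spectrum of the original one, so you’d have to decide whether that approximation is acceptable for your loss.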

Hi @AlphaBetaGamma96,

Thank you very much for answering. What is the impact of this broken gradient calculation? I can’t really pad and then adjust the result afterwards; that would be too rough an approximation. It would be easier to simply stop running the code with mixed precision (unfortunately).

Well, detaching a tensor removes it from the computational graph by definition, so you’ll get incorrect gradients or no gradients at all. Here’s an example:

import torch

x = torch.randn(1, requires_grad=True)
x_detach = x.clone() #copy so we can compare results (with same input)

y = x**2 
y.backward() #compute gradient of leaf tensors

print(x, x.grad)
#returns tensor([1.4976], requires_grad=True) tensor([2.9953])

y = (x_detach**2).detach() #now repeat with .detach()

y.backward() #crashes here as no grad function

print(x_detach, x_detach.grad)

Ok, thanks. So this ends up affecting the loss a lot, right?

The funny thing is that, despite this, the loss still keeps going down; I had never realized that something was wrong.

It depends on the example at hand, but .detach() should be used with a lot of care, especially when computing a loss.

You’re effectively telling autograd to ignore the operations that created that tensor, which will affect the gradient calculation.

Ok. Is there a way for me to assess how much this broken gradient actually affected the loss? I’ve run hundreds of experiments with this code; should I discard them because of this? I mean, are the results invalid? Thanks again!

Well, you’d have to check your code, but if you want to backpropagate through your FFT call and you’ve detached its inputs (by moving them to JAX, for example), then I would say there’s a mistake.

You can always run your code again fully in PyTorch and see if there are any differences between the ‘new’ version and the ‘old’ version. A sketch of what I mean is below.
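By “fully in PyTorch” I essentially mean your own commented-out lines: compute the FFT with torch.fft on float32 tensors inside the autocast block, so the FFT term stays on the graph. Untested sketch, reusing the variable names from your snippet:

# inside the autocast block, in place of the JAX calls
ffty   = torch.fft.fftn(y.float(), dim=(-1, -2, -3))       # float32 FFT has no power-of-2 restriction
fftout = torch.fft.fftn(output.float(), dim=(-1, -2, -3))  # not detached, so gradients can flow back

fft_res_abs     = torch.abs(fftout - ffty)
loss_fft_l1     = torch.mean(fft_res_abs)
loss_fft_l1_rel = torch.mean(fft_res_abs / (torch.abs(ffty) + 0.01))
fft_l1_max      = torch.max(fft_res_abs)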

Ok, I’m testing it. Just so I understand better, does the value of the loss change depending on whether the gradient propagates or not? Thanks.

Well, if the loss is changing there’s a gradient flowing through your program. The question is, is this the correct gradient you should have?

In terms of solving this, that’s something you’ll have to debug yourself, as you know your code better than I do! But a good starting point is to run both versions and see if you get the same results; see the sketch below.
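For example, something like this on a single batch would show you directly whether the two losses give the same gradients (compute_loss_jax and compute_loss_torch are hypothetical helpers standing in for your current detached-JAX loss and the pure-PyTorch loss):

import copy
import torch

net_a = copy.deepcopy(net)   # copy of the model for the detached (JAX FFT) loss
net_b = copy.deepcopy(net)   # identical weights for the pure-PyTorch FFT loss

loss_a = compute_loss_jax(net_a, X, y)     # placeholder: your current implementation
loss_b = compute_loss_torch(net_b, X, y)   # placeholder: the torch.fft version
loss_a.backward()
loss_b.backward()

# compare the gradients parameter by parameter
for (name, pa), (_, pb) in zip(net_a.named_parameters(), net_b.named_parameters()):
    if pa.grad is None or pb.grad is None or not torch.allclose(pa.grad, pb.grad, rtol=1e-3, atol=1e-5):
        print(f"gradients differ for {name}")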

Ok, at least I had the idea of asking here and you showed up to help me; I’ll investigate this further. Considering that the results are stochastic from one run to the next, how can I compare the values? Should I just compare the loss value at each epoch, for example, with and without this propagation? Thanks.

Have a look at torch.use_deterministic_algorithms in the docs, and make sure to set all the seeds in PyTorch/JAX/NumPy etc. to the same values.

https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
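For example, something like this at the top of your script (just a sketch; which settings you actually need depends on the ops in your model):

import os
import random
import numpy as np
import torch

SEED = 0  # any fixed value

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# some CUDA ops need this set before the first CUDA call to be deterministic
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False

# the JAX FFT call itself has no randomness, so seeding PyTorch/NumPy should
# cover weight init, data shuffling, dropout, etc.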

Ok, I will look at these things. Adjusting these seeds, I would only look at the loss values for each epoch for the comparisons, right? Thanks!

Seems right: check the convergence of the two versions, and if they match, they’re giving the same results. Small differences will most likely come from floating-point arithmetic, so don’t worry about those.

Ok, I’ll do it, thank you very much!!