Hello everybody! This neural network training code snippet uses the fftn function from the JAX library so that I can use mixed precision. I cannot use PyTorch's fftn under mixed precision because it has a restriction there: the input size must be a power of 2, and in my case it is not. The same structure is used for the validation and test parts.
My question: since I don't use PyTorch's fftn, is there a problem in my code regarding gradient computation? If so, is it serious enough to affect my network's performance?
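For reference, here is a minimal sketch of why leaving PyTorch for the FFT breaks gradient flow. NumPy stands in for JAX here as an assumption for illustration; the effect on the autograd graph is the same, since any round-trip out of PyTorch requires detaching the tensor first:

```python
import numpy as np
import torch

x = torch.randn(4, 3, requires_grad=True)

# In-graph FFT: torch.fft.fftn keeps gradients flowing.
y_torch = torch.fft.fftn(x)
print(y_torch.grad_fn is not None)  # True

# FFT computed outside PyTorch (NumPy here, but JAX behaves the same):
# we must detach to leave the graph, so the result carries no grad_fn
# and backpropagation through the FFT is impossible.
y_outside = torch.from_numpy(np.fft.fftn(x.detach().numpy()))
print(y_outside.grad_fn)  # None
```

The gradient break is structural: it happens the moment the tensor is detached to cross the framework boundary, regardless of which external library computes the FFT.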
Thank you very much for answering. What are the impacts of this destruction of the gradient calculation? I can't do the padding and then adjust the result afterwards; that would be too crude. It's easier if I just stop running the code with mixed precision (unfortunately).
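One alternative worth checking before giving up on mixed precision entirely (a sketch with placeholder shapes, not a confirmed fix for this setup): autocast can be disabled just around the FFT call, so the rest of the network still runs in mixed precision while torch.fft.fftn executes in float32, where the power-of-2 restriction does not apply, and gradients keep flowing:

```python
import torch

x = torch.randn(6, 10, requires_grad=True)  # non-power-of-2 sizes (placeholder)

# Disable autocast only for the FFT: it runs in float32, so the
# half-precision size restriction is avoided, and the op stays in
# the autograd graph (device_type would be "cuda" on GPU).
with torch.autocast(device_type="cpu", enabled=False):
    y = torch.fft.fftn(x.float())

loss = y.abs().sum()
loss.backward()
print(x.grad is not None)  # True: gradients flow through the FFT
```

Whether the float32 FFT cost is acceptable depends on how large the transforms are relative to the rest of the network.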
Well, detaching a tensor cuts it out of the autograd graph by definition, so no gradient can flow back through it. You'll get incorrect gradients or no gradients at all. Here's an example:
import torch

x = torch.randn(1, requires_grad=True)
x_detach = x.clone()  # copy so we can compare results (same input)

y = x ** 2
y.backward()  # compute gradient of leaf tensors
print(x, x.grad)  # e.g. tensor([1.4976], requires_grad=True) tensor([2.9953])

y = (x_detach ** 2).detach()  # now repeat with .detach()
y.backward()  # raises RuntimeError here: a detached tensor has no grad_fn
Ok. Is there a way for me to assess how much this gradient destruction impacted the loss? I've run hundreds of experiments with this code; should I disregard them because of this? I mean, are the results invalid? Thanks again!
Well, you'd have to check your code, but if you want to backpropagate through your FFT call and you've detached its results (for example, by moving the computation to JAX), then I would think there's a mistake.
You can always run your code again fully in PyTorch and see if there are any differences between the 'new' version and the 'old' one.
Ok, at least I had the idea to ask around here, and you showed up to help; I'll investigate this further. Given that the results are stochastic on each execution, how can I compare the values? Should I just look at the loss value for each epoch, for example, with and without gradients propagating through the FFT? Thanks.
That seems right: check convergence for the two versions, and if the loss curves match, they're giving the same results. Small differences will most likely come from floating-point arithmetic, so don't worry about those.
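One way to make stochastic runs comparable is to fix the random seeds and then compare per-epoch losses with a small tolerance. A minimal sketch, where the toy linear model and hyperparameters are placeholders for the real training loop:

```python
import torch

def train_losses(seed):
    # Toy stand-in for one full training run; the real model goes here.
    torch.manual_seed(seed)  # fix the randomness so runs are repeatable
    model = torch.nn.Linear(8, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(64, 8), torch.randn(64, 1)
    losses = []
    for _ in range(3):  # record one loss value per epoch
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

a = train_losses(0)
b = train_losses(0)
# With identical seeds the per-epoch losses should agree up to
# floating-point noise; compare with a small tolerance rather than ==.
print(all(abs(p - q) < 1e-6 for p, q in zip(a, b)))
```

The same comparison can then be run between the JAX-FFT version and the all-PyTorch version: with seeds fixed, any divergence much larger than floating-point noise points at the gradient break.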