Strange behavior of torch.angle()'s AngleBackward

I am using torch.angle() as part of my model. When I train the model, I always run into NaN gradients. The error is as follows.

[W python_anomaly_mode.cpp:104] Warning: Error detected in AngleBackward. Traceback of forward call that caused the error:
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/threading.py”, line 930, in _bootstrap
self._bootstrap_inner()
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/threading.py”, line 973, in _bootstrap_inner
self.run()
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/threading.py”, line 910, in run
self._target(*self._args, **self._kwargs)
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py”, line 61, in _worker
output = module(*input, **kwargs)
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)
File “/home/choi574/research_mk/fourier_cnn/cnn_fourier1_magPhaseCorr_noBN_noSA_f11_fixPhaseMag_noConjPad_scale/train.py”, line 106, in forward
pred, _ = self.cls_net(img_s)
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)
File “/home/choi574/research_mk/fourier_cnn/cnn_fourier1_magPhaseCorr_noBN_noSA_f11_fixPhaseMag_noConjPad_scale/model_fourier1_magPhaseCorr_noBN_noSA_f11l_fixPhaseMag_noConjPad_scale.py”, line 137, in forward
out1_i, f1_res = self.conv1_2(out1_res)
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)
File “/home/choi574/research_mk/fourier_cnn/cnn_fourier1_magPhaseCorr_noBN_noSA_f11_fixPhaseMag_noConjPad_scale/model_fourier1_magPhaseCorr_noBN_noSA_f11l_fixPhaseMag_noConjPad_scale.py”, line 64, in forward
f_conv_i = f_scale.angle() #(self.sigmoid(self.conv_i(f_pad_pos_concat))-0.5)*2*3.141592
(function _print_stack)
Traceback (most recent call last):
File “/home/choi574/research_mk/fourier_cnn/cnn_fourier1_magPhaseCorr_noBN_noSA_f11_fixPhaseMag_noConjPad_scale/train.py”, line 292, in <module>
main()
File “/home/choi574/research_mk/fourier_cnn/cnn_fourier1_magPhaseCorr_noBN_noSA_f11_fixPhaseMag_noConjPad_scale/train.py”, line 83, in main
train(args, train_loader, test_loader)
File “/home/choi574/research_mk/fourier_cnn/cnn_fourier1_magPhaseCorr_noBN_noSA_f11_fixPhaseMag_noConjPad_scale/train.py”, line 230, in train
loss_b.backward()
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/site-packages/torch/_tensor.py”, line 255, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File “/home/choi574/.conda/envs/pytorch_mk_190/lib/python3.9/site-packages/torch/autograd/__init__.py”, line 147, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function ‘AngleBackward’ returned nan values in its 0th output.

AngleBackward is always the function that produces the NaN, so I simplified my code as follows. In the code, torch.angle() sits between the FFT and the iFFT. The block that contains torch.angle() adds no new information: it splits the frequency-domain tensor into magnitude and phase, then merges them back into a complex frequency-domain tensor. Commenting out that block therefore leaves the output unchanged (please see the code below for details).

    import torch
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()  # required before registering submodules
            # kernel_size was missing in the original snippet; 3 is an assumed value
            self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.relu = nn.ReLU()

        def forward(self, x):
            '''
            x: image
            '''
            out = self.relu(self.conv(x))  # not used further in this simplified version

            ## FFT
            fft2 = torch.fft.fftshift(torch.fft.fft2(x), [-2, -1])

            ######### This block does nothing overall #########
            ## When this block is commented out, there is no NaN issue in training.
            ## When this block is active, there is a NaN issue in angle().
            a = fft2.angle()
            m = fft2.abs()

            # Build the purely imaginary tensor 0 + 1j*a, then exponentiate it
            dum = torch.zeros_like(a).unsqueeze(-1)
            aj = torch.view_as_complex(torch.cat((dum, a.unsqueeze(-1)), -1))
            aj_exp = torch.exp(aj)
            x_res = m * aj_exp
            ####################################################

            ## iFFT
            ifft2 = torch.fft.ifft2(torch.fft.ifftshift(x_res, [-2, -1]))

            return ifft2
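
As a quick sanity check of the claim that the block is an identity in the forward pass, here is a minimal sketch that recombines magnitude and phase and compares the result with the input. It uses torch.polar as a shortcut for the manual view_as_complex construction (a swapped-in equivalent, not the code from the model), and the helper name `mag_phase_roundtrip` is hypothetical.

```python
import torch

def mag_phase_roundtrip(z: torch.Tensor) -> torch.Tensor:
    """Split a complex tensor into magnitude and phase, then recombine."""
    m = z.abs()      # magnitude, real-valued
    a = z.angle()    # phase in radians, real-valued
    # torch.polar(m, a) builds m * (cos(a) + 1j * sin(a))
    return torch.polar(m, a)

if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 3, 8, 8)  # stand-in for an image batch
    z = torch.fft.fftshift(torch.fft.fft2(x), [-2, -1])
    print(torch.allclose(mag_phase_roundtrip(z), z, atol=1e-4))
```

If this prints True, the forward values are unchanged by the block, which points at the backward pass as the place where the two variants differ.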

Since the NaN always originates from torch.angle(), I searched for similar cases online and found an earlier discussion on the same topic ('AngleBackward' returned nan values - #3 by eagomez). The solution suggested there is to remove (near-)zero values from the input to torch.angle(), which is fft2 in my code. I suspect this avoids a division by zero inside torch.angle(). When I apply the suggested method, my NaN issue goes away.
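
For reference, one possible form of that workaround is sketched below. It is not necessarily the exact code from the linked thread: the helper name `safe_angle` and the threshold `eps` are hypothetical, and replacing tiny entries with 1+0j (phase 0) is just one choice of placeholder.

```python
import torch

def safe_angle(z: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Phase of z, with entries of magnitude below eps masked out first.

    eps is an assumed threshold; tune it for your data.
    """
    # Boolean mask of entries too small to have a meaningful phase
    near_zero = z.abs() < eps
    # Swap those entries for the constant 1+0j before angle() sees them,
    # so neither the forward nor the backward of angle() touches tiny values
    z_safe = torch.where(near_zero, torch.ones_like(z), z)
    return torch.angle(z_safe)

if __name__ == "__main__":
    z = torch.tensor([0 + 0j, 1e-30 + 0j, 1 + 1j],
                     dtype=torch.complex64, requires_grad=True)
    safe_angle(z).sum().backward()
    print(z.grad)
```

The masking is applied before angle() rather than to its output, so the gradient for masked entries comes from the constant branch of torch.where and is zero.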

So torch.angle() seems to produce NaN in some cases, and the issue disappears when the suggested fix is applied to its input.

I wonder why torch.angle() returns NaN in some cases and why the suggested solution fixes it. The code above is part of a simple classification model trained on ImageNet. I tried to make AngleBackward return NaN manually, but I could not; it only happens when the input meets certain conditions.
Could anyone help me understand why and how AngleBackward returns NaN?
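
One way to probe this is to evaluate the gradient of angle() on a few hand-picked inputs: an exact zero, a nonzero value so small that its squared magnitude underflows to 0 in float32, and an ordinary value. The underlying guess, which is an assumption about the implementation and not something confirmed here, is that the backward divides by the squared magnitude. The helper name `angle_grad` is hypothetical, and the result for the first two probes may depend on the PyTorch version, so no output is shown.

```python
import torch

def angle_grad(z: torch.Tensor) -> torch.Tensor:
    """Return the gradient of angle(z).sum() with respect to z, as computed by autograd."""
    z = z.detach().clone().requires_grad_(True)
    torch.angle(z).sum().backward()
    return z.grad

if __name__ == "__main__":
    # exact zero, tiny nonzero (squared magnitude underflows in float32), ordinary value
    probes = torch.tensor([0 + 0j, 1e-30 + 1e-30j, 1 + 1j], dtype=torch.complex64)
    grads = angle_grad(probes)
    for value, grad in zip(probes.tolist(), grads.tolist()):
        print(value, "->", grad)
```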