Second-Order Derivative with NaN Values - RuntimeError: Function 'SigmoidBackwardBackward0' returned nan values in its 0th output

Hello,

I am trying to compute the second-order derivative of a model, specifically the derivative of the loss with respect to the model parameters. However, what I got were tensors filled with NaN values.

The first-order derivative was computed with loss.backward(create_graph=True), and its values were fine. The second-order derivative was then computed with torch.autograd.grad(grad, param, v, retain_graph=False), which returned NaN.
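For reference, here is a minimal sketch of the pattern I am using (a toy model, dummy data, and a random vector v stand in for my actual CLIP setup):

import torch
import torch.nn as nn

# Toy stand-ins for my real model, data, and vector v
model = nn.Sequential(nn.Linear(10, 10), nn.Sigmoid(), nn.Linear(10, 1))
x, y = torch.randn(4, 10), torch.randn(4, 1)
params = list(model.parameters())
v = [torch.randn_like(p) for p in params]  # vector for the Hessian-vector product

loss = nn.functional.mse_loss(model(x), y)

# First-order gradients; create_graph=True keeps the graph for the double backward
loss.backward(create_graph=True)
grads = [p.grad for p in params]

# Second-order: differentiate (grads . v) with respect to the parameters
hvp = torch.autograd.grad(grads, params, grad_outputs=v, retain_graph=False)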

I have referred to the thread "Gradient value is nan" and set torch.autograd.set_detect_anomaly(True) in my Python script.
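Concretely, I enabled it at the top of the script like this:

import torch

# Makes autograd check each backward op for NaN outputs and print the
# forward stack trace of the op that produced them
torch.autograd.set_detect_anomaly(True)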

The warnings and errors were:
UserWarning: Error detected in SigmoidBackwardBackward0.
RuntimeError: Function 'SigmoidBackwardBackward0' returned nan values in its 0th output.

The full traceback:

Previous calculation was induced by SigmoidBackward0. Traceback of forward call that induced the previous calculation:
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/qw/qw_proj/Pruning/model/clip_img.py", line 21, in forward
    x = self.model(x)#.half()
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/qw/qw_proj/Pruning/CLIP/clip/model.py", line 236, in forward
    x = self.transformer(x)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/qw/qw_proj/Pruning/CLIP/clip/model.py", line 207, in forward
    return self.resblocks(x)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/qw/qw_proj/Pruning/CLIP/clip/model.py", line 195, in forward
    x = x.clone() + self.mlp(self.ln_2(x.clone()))
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/qw/qw_proj/Pruning/CLIP/clip/model.py", line 170, in forward
    x = x.clone() * torch.sigmoid(1.702 * x.clone())
  File "/home/qw/anaconda3/envs/dl/lib/python3.8/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at /opt/conda/conda-bld/pytorch_1666642975312/work/torch/csrc/autograd/python_anomaly_mode.cpp:121.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
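In case it helps to reproduce or inspect this, the line the anomaly detector points to is the QuickGELU activation in CLIP (x * torch.sigmoid(1.702 * x)). The double backward through it can be exercised in isolation roughly like this (a sketch with a dummy input, not my real activations):

import torch

# Dummy input standing in for the real intermediate activations
x = torch.randn(8, 768, requires_grad=True)
y = (x * torch.sigmoid(1.702 * x)).sum()

# First backward, keeping the graph so the sigmoid can be differentiated again
g, = torch.autograd.grad(y, x, create_graph=True)

# Double backward through SigmoidBackward0 -> SigmoidBackwardBackward0
gg, = torch.autograd.grad(g.sum(), x)
print(torch.isnan(gg).any())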

I would be grateful if you could suggest any potential solutions.

Thanks in advance.