Using the Jacobian in a loss function leads to inf

Hi guys, I've run into a problem. I have a network (named eikonalmodel in the code below), and I want the norm of its gradient with respect to the input points to equal some ground-truth value. Here is my code (some steps select part of the Jacobian, which should be easy to follow):

#    define loss function
loss_fun = nn.MSELoss().to(device)
#    (batch, 5), which is the model input
EIKONALFORWARDINPUT_microbatch.requires_grad_(True)
#    (batch, 5)-->(batch, 1)
OUT = eikonalmodel(EIKONALFORWARDINPUT_microbatch)
#    JAC shape is (batch, 1, 5). The line below calls my own helper, which computes the Jacobian with torch.autograd
JAC = eikonalmodel.Jacobian_by_batch(inputt_of_model = EIKONALFORWARDINPUT_microbatch, outputt_of_model = OUT.to(device), create_graph = True, retain_graph = True, allow_unused =  False)
#    (batch, 2): only input dims 1 and 2 are used for the Jacobian norm
JAC_xy = JAC[:,0,1:3].to(device)
#    (batch, )
jac_norm = torch.sqrt(torch.pow(JAC_xy, 2).sum(dim = 1))
#    (batch, )
groundtruth = torch.sqrt(torch.pow(EIKONALGRADIENT_microbatch, 2).sum(dim = 1))
#    groundtruth is selected so that it is greater than some small positive value
loss  = loss_fun(jac_norm, 1.0/groundtruth)
#
optimizer.zero_grad()
accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(parameters, args.grad_clip_value)
optimizer.step()
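Jacobian_by_batch is not shown in the post, so here is a minimal sketch of the same quantity computed with torch.autograd.grad, assuming each output row depends only on its own input row (no batch-coupling layers such as BatchNorm in training mode). The small nn.Sequential model and the helper name input_grad_norm are hypothetical. The sketch also shows a separate hazard in the torch.sqrt(torch.pow(JAC_xy, 2).sum(dim = 1)) line: if the selected gradient entries are exactly zero, the backward pass of sqrt produces nan, and a small eps inside the sqrt avoids that. Whether this is what happens in your run is not certain.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for eikonalmodel: maps (batch, 5) -> (batch, 1).
model = nn.Sequential(nn.Linear(5, 32), nn.Tanh(), nn.Linear(32, 1))


def input_grad_norm(model, x, dims=slice(1, 3), eps=1e-12):
    """Per-sample norm of d(out)/d(x[:, dims]), kept differentiable w.r.t. the weights.

    Assumes each output row depends only on its own input row, so the gradient
    of out.sum() w.r.t. x equals JAC[:, 0, :] in the question's notation.
    """
    x = x.detach().requires_grad_(True)
    out = model(x)                                                   # (batch, 1)
    (grad,) = torch.autograd.grad(out.sum(), x, create_graph=True)   # (batch, 5)
    g = grad[:, dims]                                                # (batch, 2)
    # eps keeps the backward of sqrt finite when g is exactly zero.
    return torch.sqrt(g.pow(2).sum(dim=1) + eps)                     # (batch,)


x = torch.randn(8, 5)
jac_norm = input_grad_norm(model, x)

# Cross-check sample 0 against torch.autograd.functional.jacobian.
J0 = torch.autograd.functional.jacobian(lambda z: model(z.unsqueeze(0)).squeeze(), x[0])
ref0 = torch.linalg.vector_norm(J0[1:3])

# The norm is differentiable, so a loss on it reaches the weights.
loss = (jac_norm - 1.0).pow(2).mean()
loss.backward()

# Without eps, sqrt of an exactly-zero sum of squares gives nan gradients.
z = torch.zeros(2, requires_grad=True)
torch.sqrt(z.pow(2).sum()).backward()
print(z.grad)  # tensor([nan, nan])
```

Using out.sum() instead of a full Jacobian is cheaper because it needs one backward pass per batch rather than one per output element; it is only valid under the independence assumption above.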

However, the gradient (i.e., the Jacobian) soon becomes inf. I tried decreasing the learning rate and clipping the gradient norm. The results are as follows:
BATCH: 0%| | 0/3086 [00:00<?, ?it/s]

    microbatches:   0%|          | 0/135 [00:00<?, ?it/s]
  loss -> 95178370.0 ; groundtruth--> 0.00020102841 ; jac_norm maxmin 0.039355658 0.00172582 ;EIKONALFORWARDINPUT_microbatch-->  0 ;OUT--> 0 ;max grad--> 0.0

    microbatches:   1%|          | 1/135 [00:11<26:23, 11.82s/it]
  loss -> 447264030.0 ; groundtruth--> 0.00018992453 ; jac_norm maxmin 0.0043626004 3.0593575e-10 ;EIKONALFORWARDINPUT_microbatch-->  0 ;OUT--> 0 ;max grad--> 31.946300506591797

    microbatches:   1%|▏         | 2/135 [00:14<14:09,  6.39s/it]
  loss -> 189443070.0 ; groundtruth--> 0.0002351522 ; jac_norm maxmin 0.0549314 1.8008926e-05 ;EIKONALFORWARDINPUT_microbatch-->  0 ;OUT--> 0 ;max grad--> 3.616215467453003

    microbatches:   2%|▏         | 3/135 [00:16<10:03,  4.57s/it]
  loss -> 201413380.0 ; groundtruth--> 0.000168255 ; jac_norm maxmin 0.0391026 1.2644264e-05 ;EIKONALFORWARDINPUT_microbatch-->  0 ;OUT--> 0 ;max grad--> 50.89857864379883

    microbatches:   3%|▎         | 4/135 [00:19<08:05,  3.70s/it]
  loss -> 345555520.0 ; groundtruth--> 0.00017404387 ; jac_norm maxmin 7.554859 8.869787e-09 ;EIKONALFORWARDINPUT_microbatch-->  0 ;OUT--> 0 ;max grad--> 22.646554946899414

    microbatches:   4%|▎         | 5/135 [00:21<06:59,  3.23s/it]
  loss -> inf ; groundtruth--> 0.0003031217 ; jac_norm maxmin inf inf
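Since lowering the learning rate and clipping did not prevent the inf, it can help to stop at the first non-finite value before it reaches the weights. It does not fix the cause, but it keeps one bad microbatch from poisoning the model. Below is a minimal plain-PyTorch sketch with a hypothetical toy model and a hypothetical helper guarded_step; in your loop, loss.backward() corresponds to accelerator.backward(loss). It relies on torch.nn.utils.clip_grad_norm_ returning the total gradient norm, which is checked before optimizer.step().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(5, 1)  # hypothetical stand-in for eikonalmodel
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


def guarded_step(loss, max_norm=1.0):
    """Apply one optimizer step unless the loss or the gradient norm is non-finite.

    Returns True if the step was applied, False if it was skipped.
    """
    optimizer.zero_grad()
    if not torch.isfinite(loss):
        return False  # skip before backward, so nothing reaches the weights
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(total_norm):
        optimizer.zero_grad()  # discard the non-finite gradients
        return False
    optimizer.step()
    return True


x = torch.randn(4, 5)
applied = guarded_step(model(x).pow(2).mean())           # finite loss: step applied
w_before = model.weight.detach().clone()
skipped = guarded_step(model(x).mean() * float("inf"))   # non-finite loss: step skipped
```

To locate the culprit, torch.autograd.set_detect_anomaly(True) makes the backward pass raise an error at the first operation that returns nan values, at a noticeable speed cost.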

Your loss function is:

loss = loss_fun(jac_norm, 1.0/groundtruth)

When groundtruth contains very small values, the reciprocal 1.0/groundtruth becomes extremely large (around 4975 for the groundtruth of 0.000201 in your first log line). This creates two issues:

  1. Numerical instability: the MSE loss between small jac_norm values and large reciprocal targets produces massive gradients.
  2. Gradient accumulation: these large gradients are backpropagated through the Jacobian graph (create_graph = True) into the weights, where they can compound and eventually overflow to infinity.
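To put numbers on point 1, here is the squared error and its gradient for a single sample at the values in the first log line (groundtruth 0.00020102841, jac_norm 0.039355658), with jac_norm treated as a leaf tensor. The gradient with respect to jac_norm alone is already on the order of 1e4 before it is backpropagated into the network.

```python
import torch
import torch.nn as nn

# Single-sample values taken from the first log line; jac_norm is a leaf here
# so that d(loss)/d(jac_norm) can be inspected directly.
groundtruth = torch.tensor([0.00020102841])
target = 1.0 / groundtruth                      # roughly 4974
jac_norm = torch.tensor([0.039355658], requires_grad=True)

loss = nn.MSELoss()(jac_norm, target)           # roughly 2.5e7
loss.backward()
# For one sample, d(loss)/d(jac_norm) = 2 * (jac_norm - target), roughly -9949.
print(target.item(), loss.item(), jac_norm.grad.item())
```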

Consider these more stable alternatives:

# More stable loss formulation
eps = 1e-8
alpha = 0.1  # weighting factor

# Option 1: Match the ratio instead of using reciprocal
loss = loss_fun(jac_norm * groundtruth, torch.ones_like(jac_norm))

# Option 2: If you must use the reciprocal, bound it
safe_target = torch.clamp(1.0 / (groundtruth + eps), min=0, max=100)
loss = loss_fun(jac_norm, safe_target)

# Option 3: add regularization
jac_reg = alpha * torch.mean(torch.pow(JAC_xy, 2))
total_loss = loss + jac_reg
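Option 1 is more than a rescaled target: expanding the square gives (jac_norm * g - 1)^2 = g^2 * (jac_norm - 1/g)^2, so it equals the original loss with each sample weighted by groundtruth^2. A quick numerical check on hypothetical random tensors whose magnitudes loosely mimic the logs: with g between about 3e-4 and 1.3e-3, the weight is roughly 1e-7 to 2e-6, which removes the huge gradients but also most of the training signal.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical values: groundtruth in [3e-4, 1.3e-3), jac_norm in [0, 1).
g = torch.rand(16, dtype=torch.float64) * 1e-3 + 3e-4
j = torch.rand(16, dtype=torch.float64, requires_grad=True)

loss_orig = F.mse_loss(j, 1.0 / g)                        # original loss
loss_opt1 = F.mse_loss(j * g, torch.ones_like(j))         # Option 1
loss_weighted = (g ** 2 * (j - 1.0 / g) ** 2).mean()      # original, weighted by g^2

(grad_orig,) = torch.autograd.grad(loss_orig, j)
(grad_opt1,) = torch.autograd.grad(loss_opt1, j)

print(torch.allclose(loss_opt1, loss_weighted))           # Option 1 == weighted original
print(torch.allclose(grad_opt1, g ** 2 * grad_orig))      # gradient scaled by g^2
ratio = grad_opt1 / grad_orig
print(ratio.min().item(), ratio.max().item())
```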

Dear Hamza, thank you for your suggestion. Since the reciprocal must be used, I checked its values; the min and max are:

 (tensor(479.3397), tensor(999771.3125))

I tried the first and third alternatives. Training runs for a few more iterations than with the previous loss definition, but the problem still exists. Without regularization:

BATCH:   0%|          | 0/3086 [00:00<?, ?it/s]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 0.9999313 ; groundtruth maxmin--> 0.0006271558 0.0003031217 ; jac_norm maxmin 0.67917717 4.0782093e-09 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 4.516813532973174e-06

        microbatches: 100%|██████████| 1/1 [00:01<00:00,  1.68s/it]

BATCH:   0%|          | 1/3086 [00:02<2:01:26,  2.36s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 0.999897 ; groundtruth maxmin--> 0.0006206568 0.00030541696 ; jac_norm maxmin 0.74772996 3.990999e-08 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 6.692044280498521e-06

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.03it/s]

BATCH:   0%|          | 2/3086 [00:03<1:37:43,  1.90s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 0.99999356 ; groundtruth maxmin--> 0.0006206568 0.00030476737 ; jac_norm maxmin 0.06769135 5.798581e-10 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 4.164444078469387e-07

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]

BATCH:   0%|          | 3/3086 [00:05<1:31:57,  1.79s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 0.99999934 ; groundtruth maxmin--> 0.0010546517 0.00030287757 ; jac_norm maxmin 0.0023109436 1.7275322e-09 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 4.4113612318597006e-08

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.67it/s]

BATCH:   0%|          | 4/3086 [00:06<1:21:20,  1.58s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 0.99986094 ; groundtruth maxmin--> 0.00097294734 0.00030541696 ; jac_norm maxmin 1.1393903 3.1171703e-09 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 9.107193363888655e-06

        microbatches: 100%|██████████| 1/1 [00:01<00:00,  1.93s/it]

BATCH:   0%|          | 5/3086 [00:09<1:40:15,  1.95s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 0.9999345 ; groundtruth maxmin--> 0.0011556828 0.00031172478 ; jac_norm maxmin 0.4403629 5.680148e-07 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 4.1857997530314606e-06

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]

With regularization (alpha=1; I also tried alpha=0.1 and the results were basically similar):
BATCH: 0%| | 0/3086 [00:00<?, ?it/s]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 1.0168471 ; groundtruth maxmin--> 0.0006271558 0.0003031217 ; jac_norm maxmin 0.67917717 4.0782093e-09 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 0.0021924262400716543

        microbatches: 100%|██████████| 1/1 [00:01<00:00,  1.81s/it]

BATCH:   0%|          | 1/3086 [00:02<2:20:00,  2.72s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 1.0312529 ; groundtruth maxmin--> 0.0006206568 0.00030541696 ; jac_norm maxmin 0.74734956 3.9890118e-08 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 0.004061009269207716

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s]

BATCH:   0%|          | 2/3086 [00:04<1:44:01,  2.02s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 1.0002187 ; groundtruth maxmin--> 0.0006206568 0.00030476737 ; jac_norm maxmin 0.06759368 5.790319e-10 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 2.8704864234896377e-05

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]

BATCH:   0%|          | 3/3086 [00:05<1:33:04,  1.81s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 0.9999997 ; groundtruth maxmin--> 0.0010546517 0.00030287757 ; jac_norm maxmin 0.002307556 1.724971e-09 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 3.267977222165541e-09

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.67it/s]

BATCH:   0%|          | 4/3086 [00:07<1:24:13,  1.64s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 1.0875233 ; groundtruth maxmin--> 0.00097294734 0.00030541696 ; jac_norm maxmin 1.1377246 3.1125098e-09 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 0.011472181417047977

        microbatches: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]

BATCH:   0%|          | 5/3086 [00:09<1:39:59,  1.95s/it]

        microbatches:   0%|          | 0/1 [00:00<?, ?it/s]
      loss norm-> 1.0109061 ; groundtruth maxmin--> 0.0011556828 0.00031172478 ; jac_norm maxmin 0.43857938 5.6562624e-07 ;input nan number-->  0 ;OUT nan number--> 0 ;max grad--> 0.001391432830132544

        microbatches: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]
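The loss norm stuck near 1.0 in the no-regularization run can be checked from the logged maxima alone. In the first step, max jac_norm is 0.67917717 and max groundtruth is 0.0006271558, so every product jac_norm * groundtruth is at most about 4.26e-4, and Option 1 cannot fall much below 1. The second half uses a hypothetical batch (size and values invented, drawn inside the logged ranges) to show that the gradient with respect to jac_norm is bounded by 2 * max(groundtruth) / N, which is consistent with the very small "max grad" values in the logs.

```python
import torch
import torch.nn.functional as F

# Values printed in the first step of the no-regularization run.
jac_norm_max = 0.67917717
groundtruth_max = 0.0006271558
groundtruth_min = 0.0003031217
logged_loss = 0.9999313

# jac_norm >= 0, so every product p_i = jac_norm_i * groundtruth_i lies in [0, prod_upper].
prod_upper = jac_norm_max * groundtruth_max        # about 4.26e-4
# Option 1 is mean((p_i - 1)^2), so it cannot fall below (1 - prod_upper)^2.
loss_lower = (1.0 - prod_upper) ** 2               # about 0.99915
print(prod_upper, loss_lower, logged_loss >= loss_lower)

# Hypothetical batch drawn inside the logged ranges, to look at the gradient size.
torch.manual_seed(0)
N = 64  # hypothetical batch size
g = torch.empty(N).uniform_(groundtruth_min, groundtruth_max)
j = torch.empty(N).uniform_(0.0, jac_norm_max).requires_grad_(True)
loss = F.mse_loss(j * g, torch.ones_like(j))
loss.backward()
# d(loss)/d(j_i) = 2 * (p_i - 1) * g_i / N, so |grad| <= 2 * groundtruth_max / N.
print(loss.item(), j.grad.abs().max().item(), 2 * groundtruth_max / N)
```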