Loss is NaN when create_graph is set to True

I want to compute the gradients of the output with respect to the input using autograd.grad with create_graph=True. However, the model output logit_pred_vals becomes NaN after several epochs. If I set create_graph=False, there is no problem with the loss. Does anyone have the same problem?

input_grads = torch.autograd.grad(outputs=logit_pred_vals, inputs=x_in,
                                  grad_outputs=torch.ones_like(logit_pred_vals).to(self.device),
                                  retain_graph=True, create_graph=True)[0]
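
For context, here is a minimal, self-contained sketch of the same pattern on a toy nn.Sequential model (not my actual model), with torch.autograd.set_detect_anomaly(True) enabled to try to localize where the NaN first appears:

import torch
import torch.nn as nn

# Anomaly detection slows things down, so enable it only while debugging;
# it reports the backward op that first produces a NaN/Inf.
torch.autograd.set_detect_anomaly(True)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))  # toy model
x_in = torch.randn(4, 8, requires_grad=True)

logit_pred_vals = model(x_in)
input_grads = torch.autograd.grad(outputs=logit_pred_vals, inputs=x_in,
                                  grad_outputs=torch.ones_like(logit_pred_vals),
                                  retain_graph=True, create_graph=True)[0]

# Check for non-finite values before they reach the loss.
print(torch.isfinite(logit_pred_vals).all().item(), torch.isfinite(input_grads).all().item())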

Epoch 1: Training Loss: 1.1912/ Training Correctness Loss: 1.0095 / Training Prior Loss: 0.1816/ Val Correctness Loss: 0.8354 / Val F1-Score: 0.518459 / Training Time: 3.411540 / Inference Time: 0.118823
Epoch 2: Training Loss: 1.0043/ Training Correctness Loss: 0.8510 / Training Prior Loss: 0.1534/ Val Correctness Loss: 0.7673 / Val F1-Score: 0.518459 / Training Time: 3.372341 / Inference Time: 0.119534
Epoch 3: Training Loss: 0.9447/ Training Correctness Loss: 0.7960 / Training Prior Loss: 0.1487/ Val Correctness Loss: 0.7258 / Val F1-Score: 0.518459 / Training Time: 3.380471 / Inference Time: 0.121105
Epoch 4: Training Loss: 0.9273/ Training Correctness Loss: 0.7808 / Training Prior Loss: 0.1465/ Val Correctness Loss: 0.6367 / Val F1-Score: 0.518459 / Training Time: 3.388045 / Inference Time: 0.120462
Epoch 5: Training Loss: nan/ Training Correctness Loss: nan / Training Prior Loss: nan/ Val Correctness Loss: nan / Val F1-Score: 0.518459 / Training Time: 3.394484 / Inference Time: 0.114097
Epoch 6: Training Loss: nan/ Training Correctness Loss: nan / Training Prior Loss: nan/ Val Correctness Loss: nan / Val F1-Score: 0.518459 / Training Time: 3.223722 / Inference Time: 0.113659
Epoch 7: Training Loss: nan/ Training Correctness Loss: nan / Training Prior Loss: nan/ Val Correctness Loss: nan / Val F1-Score: 0.518459 / Training Time: 3.225215 / Inference Time: 0.114736

Could you post a minimal and executable code snippet reproducing this issue, please?

Please see the training code below and let me know if you need more information.

for epoch in range(self.epochs):

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    start = time()
    # Training steps
    self.model.train()
    training_loss = 0.0
    correctness_loss = 0.0
    prior_att_loss = 0.0
    for batch_num, batch in enumerate(dataset_train):

        x_in, y_in = batch
        optimizer.zero_grad()
        if use_prior:
            x_in.requires_grad = True
            logit_pred_vals = self.model(x_in)
            # Gradients of the logits w.r.t. the input; create_graph=True keeps them
            # differentiable so they can feed into the prior loss below.
            input_grads = torch.autograd.grad(outputs=logit_pred_vals, inputs=x_in,
                                              grad_outputs=torch.ones_like(logit_pred_vals).to(self.device),
                                              retain_graph=True, create_graph=True)[0]

            input_grads = input_grads * x_in  # gradient * input attribution
            x_in.requires_grad = False
            loss_1 = loss_fn(logit_pred_vals, y_in)
            loss_2 = self.fourier_att_prior_loss(y_in, input_grads, 50, 0.2, 3)
            loss = loss_1 + loss_2
            correctness_loss += loss_1.item()
            prior_att_loss += loss_2.item()
            training_loss += loss.item()
        else:
            logit_pred_vals = self.model(x_in)
            loss = loss_fn(logit_pred_vals, y_in)
            training_loss += loss.item()

        loss.backward()   # Backpropagate to compute parameter gradients
        optimizer.step()  # Apply the weight update

    end = time()

    total_time = end - start
    correctness_loss /= len(dataset_train)
    prior_att_loss /= len(dataset_train)
    training_loss /= len(dataset_train)

input_grads is used to calculate a secondary loss in addition to the normal loss, which uses torch.nn.CrossEntropyLoss().
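
For reference, the overall pattern is double backpropagation: the correctness loss is the usual cross-entropy on the logits, and the secondary loss is computed from input_grads, which stays differentiable because of create_graph=True. Here is a stripped-down, self-contained sketch of that pattern with a toy model and a simple gradient-magnitude penalty standing in for my fourier_att_prior_loss (whose implementation is not shown here):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 3))  # toy model
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x_in = torch.randn(4, 8, requires_grad=True)
y_in = torch.randint(0, 3, (4,))

logit_pred_vals = model(x_in)
input_grads = torch.autograd.grad(outputs=logit_pred_vals, inputs=x_in,
                                  grad_outputs=torch.ones_like(logit_pred_vals),
                                  retain_graph=True, create_graph=True)[0]

loss_1 = loss_fn(logit_pred_vals, y_in)        # correctness loss
loss_2 = (input_grads * x_in).pow(2).mean()    # placeholder for fourier_att_prior_loss
loss = loss_1 + loss_2

optimizer.zero_grad()
loss.backward()   # the second backward pass also flows through input_grads
optimizer.step()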

I suspect that the gradients explode when create_graph=True is set, so the output of the model becomes all NaN.
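
If exploding gradients really are the cause, the two things I plan to try are gradient-norm clipping and skipping the update when the loss is already non-finite. A self-contained sketch of that idea (again with a toy model and a placeholder secondary loss, not my real setup):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))  # toy model
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x_in = torch.randn(4, 8, requires_grad=True)
y_in = torch.randint(0, 3, (4,))

logit_pred_vals = model(x_in)
input_grads = torch.autograd.grad(outputs=logit_pred_vals, inputs=x_in,
                                  grad_outputs=torch.ones_like(logit_pred_vals),
                                  retain_graph=True, create_graph=True)[0]
loss = loss_fn(logit_pred_vals, y_in) + (input_grads * x_in).pow(2).mean()

optimizer.zero_grad()
if torch.isfinite(loss):
    loss.backward()
    # Cap the global gradient norm so one bad batch cannot blow up the weights.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
else:
    print("Non-finite loss, skipping this update")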