Self-defined function backward

Hello everyone, I have a question in backward function

import torch
import math

def ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device):
    c = torch.tensor([299792458.0], device=device)
    Tcoh = torch.tensor([5], device=device)
    PRF = torch.tensor([20], device=device)
    fc = torch.tensor([9.7e9], device=device)
    Tp = torch.tensor([5e-4], device=device)
    B = torch.tensor([30e6], device=device)
    k = B / Tp
    fs = 1.2 * B
    Tr = 1 / PRF
    Na = torch.round(PRF * Tcoh)
    Na = Na + torch.remainder(Na, 2)
    Tcoh = Na * Tr
    lambda0 = c / fc

    P = rotation_state[0]
    gamma = rotation_state[1]
    phi = rotation_state[2]

    axis_x = torch.sin(gamma) * torch.cos(phi)
    axis_y = torch.sin(gamma) * torch.sin(phi)
    axis_z = torch.cos(gamma)
    
    Ry = torch.tensor([
        [torch.cos(gamma), 0, torch.sin(gamma)],
        [0, 1, 0],
        [-torch.sin(gamma), 0, torch.cos(gamma)]
    ], dtype=torch.float32, device=device).t()

    Rz = torch.tensor([
        [torch.cos(phi), -torch.sin(phi), 0],
        [torch.sin(phi), torch.cos(phi), 0],
        [0, 0, 1]
    ], dtype=torch.float32, device=device).t()
    
    vextex = vextex @ Ry @ Rz
    normal = normal @ Ry @ Rz

    st = time_start + torch.linspace(-Tr.item() * Na.item() / 2, Tr.item() * Na.item() / 2, Na.item(), device=device)
    vextex = vextex.t()
    normal = normal.t()

    R_box = torch.empty(Na.item(), vextex.shape[1], device=device)
    sigma_box = torch.empty(Na.item(), vextex.shape[1], device=device)

    for i in range(Na.item()):
        theta = 1 / P * st[i] * math.pi / 1800
        Rotmat = torch.tensor([
            [axis_x**2 + (1 - axis_x**2) * torch.cos(theta),
             axis_x * axis_y * (1 - torch.cos(theta)) - axis_z * torch.sin(theta),
             axis_x * axis_z * (1 - torch.cos(theta)) + axis_y * torch.sin(theta)],
            [axis_x * axis_y * (1 - torch.cos(theta)) + axis_z * torch.sin(theta),
             axis_y**2 + (1 - axis_y**2) * torch.cos(theta),
             axis_y * axis_z * (1 - torch.cos(theta)) - axis_x * torch.sin(theta)],
            [axis_x * axis_z * (1 - torch.cos(theta)) - axis_y * torch.sin(theta),
             axis_y * axis_z * (1 - torch.cos(theta)) + axis_x * torch.sin(theta),
             axis_z**2 + (1 - axis_z**2) * torch.cos(theta)]
        ], dtype=torch.float32, device=device)
        
        vextex_rot = Rotmat @ vextex
        normal_rot = Rotmat @ normal
        horizon_i = horizon[i, :]
        horizon_i = horizon_i / torch.norm(horizon_i)
        R = horizon_i.t() @ vextex_rot
        R_box[i, :] = R
        sigma = torch.clamp(-horizon_i.t() @ normal_rot, min=0).squeeze(0)
        sigma_box[i, :] = sigma

    Doppler_domain = -2 * (R_box[Na.item() - 1, :] - R_box[0, :]) / Tcoh / lambda0
    Range_domain = (R_box[Na.item() - 1, :] + R_box[0, :]) / 2
    sigma = torch.mean(sigma_box, dim=0)

    fs_range_window = torch.tensor([10.0], device=device)
    fs_doppler_window = torch.tensor([10.0], device=device)
    N_range_window = torch.tensor([100.0], device=device)
    N_doppler_window = torch.tensor([100.0], device=device)
    range_res = c / 2 / B
    theta = 2 * math.pi * Tcoh / P / 3600
    doppler_res = lambda0 / 2 / theta
    
    Range_map, Doppler_map = torch.meshgrid(
        torch.linspace(-50, 50, 100, device=device), torch.linspace(-6, 6, 100, device=device), indexing='xy')
    z = torch.zeros_like(Range_map, device=device)
    
    for i in range(vextex.shape[1]):
        z += torch.abs(sigma[i] * sinc_windowed(1 / range_res * (Range_map - Range_domain[i]), fs_range_window, N_range_window) *
                      sinc_windowed(1 / doppler_res * (Doppler_map - Doppler_domain[i]), fs_doppler_window, N_doppler_window))
    z = z / torch.max(z)
    return z

def sinc_windowed(x, fs, N):
    return torch.sinc(x) * b_window(x, fs, N)

def b_window(x, fs, N):
    return ((x * fs < N / 2) & (x * fs > -N / 2)).float() * \
           (0.42 - 0.5 * torch.cos(2 * math.pi * (x * fs + N / 2) / (N - 1)) + 0.08 * torch.cos(4 * math.pi * (x * fs + N / 2) / (N - 1)))

def numerical_gradient(f, x, h=1e-5):
    grad = torch.zeros_like(x)
    for i in range(len(x)):
        x_i = x.clone()
        x_i[i] += h
        grad[i] = (f(x_i) - f(x)) / h
    return grad
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
time_start = torch.tensor([0.0], device=device)
vextex = torch.randn(100, 3, device=device)
normal = torch.randn(100, 3, device=device)
horizon = torch.randn(100, 3, device=device)
rotation_state = torch.tensor([0.1, 0.1, 0.1], requires_grad=True, device=device)
optimizer = torch.optim.Adam([rotation_state], lr=0.01)

for _ in range(20):
    optimizer.zero_grad()
    z = ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device)
    loss = torch.sum(z)
    # torch.autograd.gradcheck
    loss.backward()
    optimizer.step()

    def loss_func(rotation_state):
        return torch.sum(ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device))

    h = 1e-5
    grad_numerical = numerical_gradient(loss_func, rotation_state.clone().detach(), h)
    print(f'Loss: {loss.item()}, Gradients: {rotation_state.grad}, Numerical Gradients: {grad_numerical}')


I tried to calculate the numerical gradient and the automatic gradient of rotation_state in function ISAR_imaging, and found that the values are not the same at all, and some variables that should have a gradient do not have one, I wish to know how to solve this problem