Hello everyone, I have a question in backward function
import torch
import math
def ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device):
c = torch.tensor([299792458.0], device=device)
Tcoh = torch.tensor([5], device=device)
PRF = torch.tensor([20], device=device)
fc = torch.tensor([9.7e9], device=device)
Tp = torch.tensor([5e-4], device=device)
B = torch.tensor([30e6], device=device)
k = B / Tp
fs = 1.2 * B
Tr = 1 / PRF
Na = torch.round(PRF * Tcoh)
Na = Na + torch.remainder(Na, 2)
Tcoh = Na * Tr
lambda0 = c / fc
P = rotation_state[0]
gamma = rotation_state[1]
phi = rotation_state[2]
axis_x = torch.sin(gamma) * torch.cos(phi)
axis_y = torch.sin(gamma) * torch.sin(phi)
axis_z = torch.cos(gamma)
Ry = torch.tensor([
[torch.cos(gamma), 0, torch.sin(gamma)],
[0, 1, 0],
[-torch.sin(gamma), 0, torch.cos(gamma)]
], dtype=torch.float32, device=device).t()
Rz = torch.tensor([
[torch.cos(phi), -torch.sin(phi), 0],
[torch.sin(phi), torch.cos(phi), 0],
[0, 0, 1]
], dtype=torch.float32, device=device).t()
vextex = vextex @ Ry @ Rz
normal = normal @ Ry @ Rz
st = time_start + torch.linspace(-Tr.item() * Na.item() / 2, Tr.item() * Na.item() / 2, Na.item(), device=device)
vextex = vextex.t()
normal = normal.t()
R_box = torch.empty(Na.item(), vextex.shape[1], device=device)
sigma_box = torch.empty(Na.item(), vextex.shape[1], device=device)
for i in range(Na.item()):
theta = 1 / P * st[i] * math.pi / 1800
Rotmat = torch.tensor([
[axis_x**2 + (1 - axis_x**2) * torch.cos(theta),
axis_x * axis_y * (1 - torch.cos(theta)) - axis_z * torch.sin(theta),
axis_x * axis_z * (1 - torch.cos(theta)) + axis_y * torch.sin(theta)],
[axis_x * axis_y * (1 - torch.cos(theta)) + axis_z * torch.sin(theta),
axis_y**2 + (1 - axis_y**2) * torch.cos(theta),
axis_y * axis_z * (1 - torch.cos(theta)) - axis_x * torch.sin(theta)],
[axis_x * axis_z * (1 - torch.cos(theta)) - axis_y * torch.sin(theta),
axis_y * axis_z * (1 - torch.cos(theta)) + axis_x * torch.sin(theta),
axis_z**2 + (1 - axis_z**2) * torch.cos(theta)]
], dtype=torch.float32, device=device)
vextex_rot = Rotmat @ vextex
normal_rot = Rotmat @ normal
horizon_i = horizon[i, :]
horizon_i = horizon_i / torch.norm(horizon_i)
R = horizon_i.t() @ vextex_rot
R_box[i, :] = R
sigma = torch.clamp(-horizon_i.t() @ normal_rot, min=0).squeeze(0)
sigma_box[i, :] = sigma
Doppler_domain = -2 * (R_box[Na.item() - 1, :] - R_box[0, :]) / Tcoh / lambda0
Range_domain = (R_box[Na.item() - 1, :] + R_box[0, :]) / 2
sigma = torch.mean(sigma_box, dim=0)
fs_range_window = torch.tensor([10.0], device=device)
fs_doppler_window = torch.tensor([10.0], device=device)
N_range_window = torch.tensor([100.0], device=device)
N_doppler_window = torch.tensor([100.0], device=device)
range_res = c / 2 / B
theta = 2 * math.pi * Tcoh / P / 3600
doppler_res = lambda0 / 2 / theta
Range_map, Doppler_map = torch.meshgrid(
torch.linspace(-50, 50, 100, device=device), torch.linspace(-6, 6, 100, device=device), indexing='xy')
z = torch.zeros_like(Range_map, device=device)
for i in range(vextex.shape[1]):
z += torch.abs(sigma[i] * sinc_windowed(1 / range_res * (Range_map - Range_domain[i]), fs_range_window, N_range_window) *
sinc_windowed(1 / doppler_res * (Doppler_map - Doppler_domain[i]), fs_doppler_window, N_doppler_window))
z = z / torch.max(z)
return z
def sinc_windowed(x, fs, N):
return torch.sinc(x) * b_window(x, fs, N)
def b_window(x, fs, N):
return ((x * fs < N / 2) & (x * fs > -N / 2)).float() * \
(0.42 - 0.5 * torch.cos(2 * math.pi * (x * fs + N / 2) / (N - 1)) + 0.08 * torch.cos(4 * math.pi * (x * fs + N / 2) / (N - 1)))
def numerical_gradient(f, x, h=1e-5):
grad = torch.zeros_like(x)
for i in range(len(x)):
x_i = x.clone()
x_i[i] += h
grad[i] = (f(x_i) - f(x)) / h
return grad
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
time_start = torch.tensor([0.0], device=device)
vextex = torch.randn(100, 3, device=device)
normal = torch.randn(100, 3, device=device)
horizon = torch.randn(100, 3, device=device)
rotation_state = torch.tensor([0.1, 0.1, 0.1], requires_grad=True, device=device)
optimizer = torch.optim.Adam([rotation_state], lr=0.01)
for _ in range(20):
optimizer.zero_grad()
z = ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device)
loss = torch.sum(z)
# torch.autograd.gradcheck
loss.backward()
optimizer.step()
def loss_func(rotation_state):
return torch.sum(ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device))
h = 1e-5
grad_numerical = numerical_gradient(loss_func, rotation_state.clone().detach(), h)
print(f'Loss: {loss.item()}, Gradients: {rotation_state.grad}, Numerical Gradients: {grad_numerical}')
I tried to calculate the numerical gradient and the automatic gradient of rotation_state in function ISAR_imaging, and found that the values are not the same at all, and some variables that should have a gradient do not have one, I wish to know how to solve this problem