# Self-defined function backward

Hello everyone, I have a question about the backward pass of a function I defined myself.

```python
import torch
import math

def ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device):
    """Render a normalised range-Doppler image that is differentiable in ``rotation_state``.

    Every rotation matrix is assembled with ``torch.stack`` from the tensors
    derived from ``rotation_state``.  Building them with ``torch.tensor([...])``
    would copy the values and cut the autograd graph, so ``backward()`` would
    not reach ``gamma`` and ``phi``.

    Args:
        time_start: Tensor added to the pulse time offsets.
        vextex: Point coordinates, multiplied on the right by 3x3 matrices,
            so the last dimension must be 3 (presumably ``(N, 3)`` -- confirm
            with the caller).
        normal: Normals, transformed exactly like ``vextex``.
        horizon: Row ``i`` is the direction used for pulse ``i``; it is
            normalised to unit length here.  Needs at least ``Na`` rows
            (100 with the constants below).
        rotation_state: Indexable with three scalar entries ``(P, gamma, phi)``.
            NOTE(review): the ``pi / 1800`` and ``/ 3600`` factors suggest
            ``P`` is a period in hours -- confirm.
        device: Device on which constants and matrices are created.

    Returns:
        A 100x100 tensor divided by its own maximum.
    """
    # Radar constants.  Tcoh and PRF stay integer tensors as in the original.
    c = torch.tensor([299792458.0], device=device)
    Tcoh = torch.tensor([5], device=device)
    PRF = torch.tensor([20], device=device)
    fc = torch.tensor([9.7e9], device=device)
    B = torch.tensor([30e6], device=device)
    Tr = 1 / PRF
    Na = torch.round(PRF * Tcoh)
    Na = Na + torch.remainder(Na, 2)  # force an even pulse count
    Tcoh = Na * Tr
    lambda0 = c / fc

    # Convert once: torch.linspace and range() both need a Python int.
    n_pulses = int(Na.item())
    half_span = Tr.item() * n_pulses / 2

    P, gamma, phi = rotation_state[0], rotation_state[1], rotation_state[2]
    sin_g, cos_g = torch.sin(gamma), torch.cos(gamma)
    sin_p, cos_p = torch.sin(phi), torch.cos(phi)

    # Unit rotation axis from the two angles.
    axis_x = sin_g * cos_p
    axis_y = sin_g * sin_p
    axis_z = cos_g

    Ry = _stack_matrix([
        [cos_g, 0, sin_g],
        [0, 1, 0],
        [-sin_g, 0, cos_g],
    ], device).t()
    Rz = _stack_matrix([
        [cos_p, -sin_p, 0],
        [sin_p, cos_p, 0],
        [0, 0, 1],
    ], device).t()

    # Work with column vectors (3, N) from here on.
    vextex = (vextex @ Ry @ Rz).t()
    normal = (normal @ Ry @ Rz).t()

    st = time_start + torch.linspace(-half_span, half_span, n_pulses, device=device)

    ranges = []
    sigmas = []
    for i in range(n_pulses):
        theta = 1 / P * st[i] * math.pi / 1800
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        one_minus_cos = 1 - cos_t
        # Axis-angle rotation matrix about (axis_x, axis_y, axis_z).
        Rotmat = _stack_matrix([
            [axis_x**2 + (1 - axis_x**2) * cos_t,
             axis_x * axis_y * one_minus_cos - axis_z * sin_t,
             axis_x * axis_z * one_minus_cos + axis_y * sin_t],
            [axis_x * axis_y * one_minus_cos + axis_z * sin_t,
             axis_y**2 + (1 - axis_y**2) * cos_t,
             axis_y * axis_z * one_minus_cos - axis_x * sin_t],
            [axis_x * axis_z * one_minus_cos - axis_y * sin_t,
             axis_y * axis_z * one_minus_cos + axis_x * sin_t,
             axis_z**2 + (1 - axis_z**2) * cos_t],
        ], device)

        vextex_rot = Rotmat @ vextex
        normal_rot = Rotmat @ normal
        horizon_i = horizon[i, :]
        horizon_i = horizon_i / torch.norm(horizon_i)
        # Projection of every point onto the direction for this pulse.
        ranges.append(horizon_i @ vextex_rot)
        # Only normals pointing against the direction contribute.
        sigmas.append(torch.clamp(-horizon_i @ normal_rot, min=0))

    first_range, last_range = ranges[0], ranges[-1]
    Doppler_domain = -2 * (last_range - first_range) / Tcoh / lambda0
    Range_domain = (last_range + first_range) / 2
    sigma = torch.stack(sigmas).mean(dim=0)

    fs_range_window = torch.tensor([10.0], device=device)
    fs_doppler_window = torch.tensor([10.0], device=device)
    N_range_window = torch.tensor([100.0], device=device)
    N_doppler_window = torch.tensor([100.0], device=device)
    range_res = c / 2 / B
    total_angle = 2 * math.pi * Tcoh / P / 3600
    doppler_res = lambda0 / 2 / total_angle

    Range_map, Doppler_map = torch.meshgrid(
        torch.linspace(-50, 50, 100, device=device),
        torch.linspace(-6, 6, 100, device=device),
        indexing='xy',
    )
    z = torch.zeros_like(Range_map, device=device)

    # Accumulate one windowed-sinc response per point.
    for i in range(vextex.shape[1]):
        range_response = sinc_windowed(
            1 / range_res * (Range_map - Range_domain[i]),
            fs_range_window, N_range_window)
        doppler_response = sinc_windowed(
            1 / doppler_res * (Doppler_map - Doppler_domain[i]),
            fs_doppler_window, N_doppler_window)
        z = z + torch.abs(sigma[i] * range_response * doppler_response)

    return z / torch.max(z)


def _stack_matrix(rows, device):
    """Stack scalar entries into a float32 matrix on ``device``, keeping autograd history."""

    def as_scalar(entry):
        if isinstance(entry, torch.Tensor):
            # .to() and .reshape() are differentiable, unlike torch.tensor(entry).
            return entry.to(device=device, dtype=torch.float32).reshape(())
        return torch.tensor(entry, dtype=torch.float32, device=device)

    return torch.stack([torch.stack([as_scalar(e) for e in row]) for row in rows])

# NOTE(review): only the nested helper of `sinc_windowed` is visible in this
# excerpt; the statements that use `b_window` and produce the return value
# are not shown.
def sinc_windowed(x, fs, N):

# Window weight for `x`: a float mask that keeps samples with
# -N/2 < x*fs < N/2 (strict on both sides), multiplied by
# 0.42 - 0.5*cos(2*pi*u) + 0.08*cos(4*pi*u) where u = (x*fs + N/2) / (N - 1).
# The 0.42 / 0.5 / 0.08 coefficients match a Blackman window.
def b_window(x, fs, N):
return ((x * fs < N / 2) & (x * fs > -N / 2)).float() * \
(0.42 - 0.5 * torch.cos(2 * math.pi * (x * fs + N / 2) / (N - 1)) + 0.08 * torch.cos(4 * math.pi * (x * fs + N / 2) / (N - 1)))

# NOTE(review): orphan fragment -- the enclosing definition and the bindings
# for `h`, `grad` and `f` are not visible in this excerpt.
# For each index i: perturb a copy of `x` by `h` at position i and store the
# forward difference (f(x_i) - f(x)) / h in grad[i].
for i in range(len(x)):
x_i = x.clone()
x_i[i] += h
grad[i] = (f(x_i) - f(x)) / h
# Driver: optimise the three rotation parameters on random geometry.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
time_start = torch.tensor([0.0], device=device)
vextex = torch.randn(100, 3, device=device)
normal = torch.randn(100, 3, device=device)
horizon = torch.randn(100, 3, device=device)
# Created directly on `device` so it is a leaf tensor the optimizer can update.
rotation_state = torch.tensor([0.1, 0.1, 0.1], requires_grad=True, device=device)

# NOTE(review): the original loop called `optimizer.step()` without ever
# creating an optimizer.  Adam with lr=1e-3 is a placeholder -- replace it
# with the intended optimizer and learning rate.
optimizer = torch.optim.Adam([rotation_state], lr=1e-3)

for _ in range(20):
    # Clear gradients from the previous step; backward() accumulates into .grad.
    optimizer.zero_grad()
    z = ISAR_imaging(time_start, vextex, normal, horizon, rotation_state, device)
    loss = torch.sum(z)
    loss.backward()
    optimizer.step()

def loss_func(rotation_state):