SGD optimizer doesn't optimize passed parameter

Hello all!
I am new to ML and I participate in a cool project, in which we need to find the best noise to be projected onto a specific object that is found in pictures taken from a drone’s trajectory, so that the distance between the drone’s original estimated position and the “attacked” estimated position(the one estimated using the noise projection onto the object) is maximized.
We use PoseResNet for the pose(and position) estimation.
We use MSELoss to measure the distance(to be precised, the negative distance) in order to complete the optimization problem.
For some reason, we can’t get the optimizer to optimize the noise, and would really appreciate your help.
Below are snippets from the code - I hope it is detailed well enough and not over-detailed:

import kornia.geometry as kgm
import torch
import cv2
from PoseResNet import PoseResNet

pose_net = PoseResNet(num_layers=18, pretrained=True).to(device)
pose_net.eval()  # We're in evaluation mode of PoseResNet!!
weights_path = os.path.join(os.getcwd(), 'exp_pose_model_best.pth.tar')
weights = torch.load(weights_path, map_location=self.device)
pose_net.load_state_dict(weights['state_dict'], strict=False)

criterion = torch.nn.MSELoss(reduction='sum')

def warp_tensor(im_tensor, dst_points):
    # tensor version of "warp_image_by_obj_pos"
    img = im_tensor.float().unsqueeze(dim=0)  # BxCxHxW
    # points order from upper left clockwise.
    # the source points are the image corners.
    _, h, w = im_tensor.shape  # src size
    points_src = torch.FloatTensor([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).unsqueeze(0)
    # the destination points are the region to crop corners
    points_dst = torch.FloatTensor(np.array(dst_points)).reshape((1, -1, 2))
    # compute perspective transform
    M = kgm.get_perspective_transform(points_src, points_dst)
    # warp the original tensor by the found transform
    img_warp = kgm.warp_perspective(img, M, dsize=(h, w), align_corners=True)
    return img_warp
  
def load_noisy_image(self, traj_dir_path: str, i: int, noise_):
    img_path = os.path.abspath(os.path.join(traj_dir_path, str(i)+'.png'))
    img = cv2.imread(img_path).astype(np.float32)
    target_coords, is_obj_visible = self.get_img_coords_of_target_world_coords(traj_dir_path=traj_dir_path, i=i) 

    noise_warped = warp_tensor(noise_, target_coords)
    mask = warp_tensor(torch.ones_like(noise_), target_coords).ge(0.5)  # create a mask
    a = ku.image_to_tensor(img / 255.).unsqueeze(0).float()  # .reshape(1, 3, self.img_height, self.img_width)
    a[mask] *= 1 - self.eps  # set the target in image using the mask
    img_warped = a + self.eps * noise_warped
    img_warped = torch.clamp(img_warped, 0., 1.)

    tensor_img1_attacked = img_warped.detach()
    # tensor_display = tensor_img1_attacked.numpy().squeeze(0).transpose(1, 2, 0)
    return tensor_img1_attacked
      
    def generate_attacked_poses(self, traj_dir_path: str, traj_num: str, noise_, attacked_imgs_save_dir, epoch_num: int = -1):
        epoch_num = str(epoch_num)
        images_num = len(os.listdir(traj_dir_path)) - 1  # we don't count the log file

        with open(os.path.join(traj_dir_path, "log.json"), 'r') as log_file:
            log_fp = json.load(log_file)

        poses_num = len(log_fp) - 1  # we don't take the info of the object
        attacked_poses = torch.empty((poses_num, 3, 4))

        pos_x0 = torch.tensor(list(log_fp[1]['position'].values())).unsqueeze(1)
        R_mat_x0 = torch.tensor(quaternion2rotation_matrix(log_fp[1]['orientation']))
        global_pose_attacked = torch.cat((R_mat_x0, pos_x0), dim=1)
        global_pose_attacked = torch.cat((global_pose_attacked, torch.tensor([0, 0, 0, 1]).unsqueeze(0)), dim=0).double()
        global_pose_attacked.requires_grad_(True)
        attacked_poses[0] = global_pose_attacked[0:3, :]
        attacked_img1 = self.load_noisy_image(traj_dir_path=traj_dir_path, i=0, noise_=noise_)
        for i in range(images_num - 1):
            attacked_img2 = self.load_noisy_image(traj_dir_path=traj_dir_path, i=i+1, noise_=noise_)
            with torch.no_grad():
                pose = pose_net(attacked_img1, attacked_img2)
            pose_mat = pose_vec2mat(pose).squeeze(0).cpu().detach().double()
            pose_mat = torch.vstack((pose_mat, torch.tensor([0, 0, 0, 1])))

            global_pose_attacked = torch.inverse(pose_mat) @ global_pose_attacked
            attacked_poses[i+1] = global_pose_attacked[0:3, :].detach()
            # update
            attacked_img1 = attacked_img2
        return attacked_poses
      
    def fit(self, indices, fold_fit_dir):
        noise = torch.rand(3, self.patch_height, self.patch_width, device=self.device, requires_grad=True)
        optimizer = torch.optim.SGD([noise], lr=self.lr, momentum=self.momentum)
        traj_loss = 0
        for epoch_num in range(self.epochs_num):
            noise_ = resize2d(noise, (self.img_height, self.img_width))
            optimizer.zero_grad()
            for i, traj_num in enumerate(indices):
                curr_traj_dir = os.path.join(self.dataset_save_location, str(traj_num))
                curr_traj_generated_poses = self.generate_poses(curr_traj_dir)
                save_dir = fold_fit_dir / str(traj_num)
                curr_traj_attacked_poses = self.generate_attacked_poses(curr_traj_dir, str(traj_num), noise_, save_dir, str(epoch_num))
                curr_traj_attacked_poses.requires_grad_(True)
                traj_loss += -1 * self.loss_factor * self.criterion(curr_traj_attacked_poses[-1][:, -1], curr_traj_generated_poses[-1][:, -1]).item()
            epoch_mean_loss = torch.tensor([traj_loss / len(indices)], requires_grad=True)
            epoch_mean_loss.backward()
            optimizer.step()
            traj_loss = 0

Thanks!

Hi Asaf!

The various calls to .detach() along the code path that generates
attacked_poses “breaks the computation graph” and setting
requires_grad = True thereafter doesn’t repair the damage. That
is to say, no gradients will be backpropagated through the .detach()
calls, so gradient descent will not be able to optimize noise.

May I suggest that you look at pytorch’s autograd documentation,
in particular the Computational Graph section of pytorch’s autograd
tutorial?

Best.

K. Frank

Hi Frank,
Thanks for your answer!
I have disabled all the .detach() calls in our code, and the problem persists.
Do you have another idea maybe?

Thanks!