Orthogonal gradient to input image

Hi,

Assuming I have computed the gradient of a neural network's loss with respect to an image of shape, say, (3, 224, 224), how do I then compute the corresponding orthogonal gradient?

Thanks!

Hi Alex!

Think of your image and its gradient as both being vectors of length 3 * 224 * 224.

gradient - dot (gradient, image) * image / |image|**2 is the component
of gradient that is orthogonal to image (where dot() indicates the “dot” or inner
product between the two vectors gradient and image and |image| is the norm
or length of the vector image).
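
As a minimal sketch of that formula, the projection can be written with torch.dot on the flattened tensors. The helper name orthogonal_component and the random stand-in tensors are assumptions made for illustration; they are not part of the original question.

```python
import torch

def orthogonal_component(gradient, image):
    # Hypothetical helper: remove from `gradient` its projection onto `image`.
    g = gradient.flatten()
    x = image.flatten()
    coeff = torch.dot(g, x) / torch.dot(x, x)  # dot(gradient, image) / |image|**2
    return (g - coeff * x).reshape(gradient.shape)

torch.manual_seed(0)
image = torch.rand(3, 224, 224, dtype=torch.float64)
gradient = torch.randn(3, 224, 224, dtype=torch.float64)  # stand-in for a real gradient
orth = orthogonal_component(gradient, image)

# Cosine of the angle between orth and image; close to zero up to round-off.
cosine = torch.dot(orth.flatten(), image.flatten()) / (orth.norm() * image.norm())
```

Double precision is used here only to keep the round-off small; the same code should work in single precision with a looser tolerance.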

Best.

K. Frank

Hi Frank, thanks for your reply!
Let me write it in pseudocode to make sure I understand:

import torch

# define random image
img = torch.rand(3, 224, 224)
img = img.requires_grad_(True)

# compute the loss
loss = Net(img)
loss.backward()

# store grad
img_grad = img.grad

# view(-1) to transform into vectors of size 3 * 224 * 224
orth_vector = img_grad.view(-1) - \
              img_grad.view(-1) @ img.view(-1) * img.view(-1) / (img.view(-1).norm() ** 2)

# check they are orthogonal
assert(orth_vector @ img_grad == 0)

I checked and the last assertion isn’t true for something like that. What is wrong in my pseudocode?

I had this idea:

import torch

# define random image
img = torch.rand(1, 3, 224, 224)
img = img.requires_grad_(True)

# compute the loss
loss = Net(img)
loss.backward()

# store grad
img_grad = img.grad

orth_channels = []
# loop through each channel in the image, so we orthonormalize for
# each channel independently
for channel in range(3):
   # here I assume we have defined a function that performs
   # gram-schmidt orthonormalization; img_grad[0, channel] selects
   # the (224, 224) slice for this channel of the single batch item
   orth_channels.append(gram_schmidt(img_grad[0, channel]).unsqueeze(0)) # shape (1, 224, 224)

# I can then check that each channel component is orthogonal
# to the original gradient, for example:
# img_grad[0] @ orth_matrix[0].T will give an Identity matrix
orth_matrix = torch.cat(orth_channels)  # shape (3, 224, 224)

Cheers!

Hi Alex!

If I understand correctly what you are asking, the problem is that @ applied to
tensors with two or more dimensions gives you matrix multiplication, rather than a
scalar dot product.

Consider:

>>> import torch
>>> torch.__version__
'2.3.0'
>>> imag = torch.randn (3, 224, 224)
>>> grad = torch.randn (3, 224, 224)
>>> orth = grad - torch.dot (imag.flatten(), grad.flatten()) * imag / imag.norm()**2
>>> (orth * imag).sum()
tensor(0.0006)
>>> ((orth / orth.norm()) * (imag / imag.norm())).sum()
tensor(3.2596e-09)

(which is orthogonality up to expected round-off error).
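
To make the behaviour of @ concrete, here is a small sketch under the same random-tensor assumptions as above. It shows what @ returns for 1-D versus 2-D tensors, and it compares orth against both imag and grad. The variable names dot_1d, mat_2d, cos_orth_imag and cos_orth_grad are made up for this illustration.

```python
import torch

torch.manual_seed(0)
imag = torch.randn(3, 224, 224)
grad = torch.randn(3, 224, 224)

# On 1-D tensors, @ is an inner product and returns a 0-dim tensor.
dot_1d = imag.flatten() @ grad.flatten()

# On 2-D tensors, @ is a matrix product: (224, 224) @ (224, 224) -> (224, 224).
mat_2d = imag[0] @ grad[0].T

# Component of grad orthogonal to imag, as in the session above.
orth = grad - torch.dot(grad.flatten(), imag.flatten()) * imag / imag.norm() ** 2

# orth is orthogonal to imag (the vector that was projected out), not to grad.
cos_orth_imag = (orth * imag).sum() / (orth.norm() * imag.norm())
cos_orth_grad = (orth * grad).sum() / (orth.norm() * grad.norm())
```

Two things follow for the check itself: the quantity to test is the inner product of orth with the image rather than with the gradient, and in floating point it is safer to compare against a small tolerance (for instance on the normalized value) than to assert exact equality with 0.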

Best.

K. Frank
