Hi,
assuming I have computed the gradient of a neural network's loss with respect to an image of, say, shape (3, 224, 224), how do I then compute the corresponding orthogonal gradient?
Thanks!
Hi Alex!
Think of your image and its gradient as both being vectors of length 3 * 224 * 224.
gradient - dot (gradient, image) * image / |image|**2
is the component of gradient that is orthogonal to image (where dot() indicates the “dot” or inner product between the two vectors gradient and image, and |image| is the norm or length of the vector image).
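The formula above translates almost one-to-one into PyTorch once both tensors are flattened to 1-D. This is a minimal sketch; the function name `orthogonal_component` is hypothetical, and the random tensors only stand in for a real image and gradient:

```python
import torch


def orthogonal_component(gradient: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
    """Return the part of `gradient` that is orthogonal to `image`.

    Both tensors are treated as flat vectors of length image.numel();
    the result is reshaped back to the shape of `gradient`.
    """
    g = gradient.flatten()
    x = image.flatten()
    # dot(gradient, image) / |image|**2 is the coefficient of the projection onto image
    coeff = torch.dot(g, x) / torch.dot(x, x)
    return (g - coeff * x).view_as(gradient)


# Usage on random data with the shape from the question
torch.manual_seed(0)
image = torch.rand(3, 224, 224)
gradient = torch.randn(3, 224, 224)
orth = orthogonal_component(gradient, image)

# Normalized dot product with the image: ~0 up to float32 round-off
print(torch.dot(orth.flatten(), image.flatten()) / (orth.norm() * image.norm()))
```

Subtracting the projection is the same operation as a single Gram-Schmidt step of the gradient against the image.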
Best.
K. Frank
Hi Frank, thanks for your reply!
Let me write it in pseudocode to make sure I understand:
import torch
# define random image
img = torch.rand(3, 224, 224)
img = img.requires_grad_(True)
# compute the loss
loss = Net(img)
loss.backward()
# store grad
img_grad = img.grad
# view(-1) to transform into vectors of size 3 * 224 * 224
orth_vector = img_grad.view(-1) - \
img_grad.view(-1) @ img.view(-1) * img.view(-1) / (img.view(-1).norm() ** 2)
# check they are orthogonal
assert(orth_vector @ img_grad == 0)
I checked, and the last assertion doesn’t hold for something like this. What is wrong in my pseudocode?
I also had this idea:
import torch
# define random image
img = torch.rand(1, 3, 224, 224)
img = img.requires_grad_(True)
# compute the loss
loss = Net(img)
loss.backward()
# store grad
img_grad = img.grad
orthonormal_mtx = []
# loop through each channel in the image, so we orthonormalize for
# each channel independently
for channel in range(3):
    # here I assume we have defined a function that performs
    # Gram-Schmidt orthonormalization
    # (img_grad has shape (1, 3, 224, 224), so index the batch dimension first)
    orth_matrix = gram_schmidt(img_grad[0, channel]).unsqueeze(0)  # shape (1, 224, 224)
    orthonormal_mtx.append(orth_matrix)
    # I can then check that the rows of each channel are orthonormal,
    # for example:
    # orth_matrix[0] @ orth_matrix[0].T should give an identity matrix
orth_matrix = torch.cat(orthonormal_mtx)  # shape (3, 224, 224)
Cheers!
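Removing the component of one vector along another is exactly one Gram-Schmidt step, so a per-channel version does not need a full orthonormalization of a 224 x 224 matrix. A minimal sketch, assuming the goal is for each channel of the gradient to be orthogonal to the matching channel of the image; the function `orthogonal_per_channel` is hypothetical, and for a batched tensor of shape (1, 3, 224, 224) you would pass `img_grad[0]` and `img[0]`:

```python
import torch


def orthogonal_per_channel(gradient: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
    """Project out, channel by channel, the component of `gradient` along `image`.

    Expects tensors of shape (C, H, W); each channel is treated as a flat
    vector of length H * W.
    """
    g = gradient.flatten(start_dim=1)  # (C, H*W)
    x = image.flatten(start_dim=1)     # (C, H*W)
    # One Gram-Schmidt step per channel: projection coefficient onto x[c]
    coeff = (g * x).sum(dim=1, keepdim=True) / (x * x).sum(dim=1, keepdim=True)
    return (g - coeff * x).view_as(gradient)


torch.manual_seed(0)
image = torch.rand(3, 224, 224)
gradient = torch.randn(3, 224, 224)
orth = orthogonal_per_channel(gradient, image)

# Per-channel dot products with the image; each should be ~0 up to round-off
print((orth * image).flatten(start_dim=1).sum(dim=1))
```

Because every per-channel dot product is zero, their sum (the dot product over the whole flattened tensor) is zero too, so the result is also orthogonal to the full image; in general, though, it differs from the single whole-image projection.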
Hi Alex!
If I understand correctly what you are asking, the problem is that @ is giving you matrix multiplication, rather than a scalar dot product.
Consider:
>>> import torch
>>> torch.__version__
'2.3.0'
>>> imag = torch.randn (3, 224, 224)
>>> grad = torch.randn (3, 224, 224)
>>> orth = grad - torch.dot (imag.flatten(), grad.flatten()) * imag / imag.norm()**2
>>> (orth * imag).sum()
tensor(0.0006)
>>> ((orth / orth.norm()) * (imag / imag.norm())).sum()
tensor(3.2596e-09)
(which is orthogonality up to expected round-off error).
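Applied to the pseudocode from the question, two further points seem to matter: in `orth_vector @ img_grad` a flat vector meets a 3-D tensor, so `torch.matmul` attempts a batched matrix product rather than a dot product; and the orthogonal component is orthogonal to the image, not to the gradient, so the check should be against `img`, with a tolerance instead of `== 0`. A minimal sketch, where `net` is a hypothetical stand-in for `Net` that produces a scalar loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the network in the question; any model whose
# output is reduced to a scalar loss works the same way.
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 1))

img = torch.rand(1, 3, 224, 224, requires_grad=True)
loss = net(img).sum()  # reduce to a scalar so backward() needs no argument
loss.backward()

with torch.no_grad():
    g = img.grad.flatten()  # 1-D, length 3 * 224 * 224
    x = img.flatten()       # 1-D, so every @ below is a scalar dot product
    orth = g - (g @ x) * x / (x @ x)

    # Check against the image (not the gradient), allowing for round-off
    cos = (orth @ x) / (orth.norm() * x.norm())
    print(cos.item())
    assert abs(cos.item()) < 1e-4
```

Flattening both tensors first keeps every `@` between 1-D vectors, where `torch.matmul` returns the scalar inner product.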
Best.
K. Frank