I am training a neural network with a loss function L, and I want to constrain the network's outputs z to lie on a d-dimensional hypersphere.
For training to respect this constraint, the backpropagated gradients should be tangent to the hypersphere at the corresponding points.
To achieve this, I plan to project the gradients ∂L/∂z, where z is the network’s output, onto the tangent plane of the hypersphere. Based on my understanding, this projection can be computed as:
g_tangent = g − (g ⋅ z_normalized) z_normalized,
where g=∂L/∂z and z_normalized is the normalized version of z.
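For concreteness, this is what I mean, written out on plain tensors (the shapes here are just illustrative, not my actual network):

import torch

z = torch.randn(4, 3)                                   # stand-in for a batch of network outputs
g = torch.randn(4, 3)                                   # stand-in for dL/dz
z_normalized = z / z.norm(dim=1, keepdim=True)

# g_tangent = g - (g . z_normalized) z_normalized
g_tangent = g - (g * z_normalized).sum(dim=1, keepdim=True) * z_normalized

print((g_tangent * z_normalized).sum(dim=1))            # ~ zeros: no radial component left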
However, I’m having difficulty accessing these gradients, modifying them, and ensuring the loss is properly backpropagated after applying this projection.
I’d greatly appreciate any feedback, suggestions, or guidance on implementing this process. Thank you so much for your help!
I’m not sure I follow what you are asking, but if I understand your use case, it should
suffice to project your z onto the hypersphere and then compute your loss function
using the values of the projected z. The gradients with respect to the unprojected z
will then naturally be tangent to the hypersphere (because the objective function will
now be independent of any changes in z normal to the hypersphere, as any such
changes will have been projected away).
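To make that parenthetical concrete: writing y = z / ||z|| for the projected output, the chain rule gives ∂L/∂z = (1 / ||z||) (I − y yᵀ) ∂L/∂y, and (I − y yᵀ) annihilates the component along z, so ∂L/∂z has no radial part, i.e., it is automatically tangent to the sphere at z.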
Here is an illustrative script:
import torch

print(torch.__version__)

_ = torch.manual_seed(2024)

t = torch.tensor([2.0, 2.0])                    # two-dimensional target point
x = torch.randn(5, 2, requires_grad=True)       # batch of five two-dimensional starting points
print('x ...')
print(x)

y = x / (x**2).sum(dim=1, keepdim=True).sqrt()  # project x onto unit circle
print('y = x projected onto unit circle ...')
print(y)

torch.nn.MSELoss(reduction='sum')(x, t).backward()
print('x.grad (no projection) ...')
print(x.grad)
print('x.grad <dot> x ...')
print((x * x.grad).sum(dim=1))                  # x.grad not tangent to unit circle

x.grad = None

torch.nn.MSELoss(reduction='sum')(y, t).backward()
print('x.grad (with projection) ...')
print(x.grad)
print('x.grad <dot> x ...')
print((x * x.grad).sum(dim=1))                  # this version of x.grad is tangent to unit circle
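That said, if you do want to get your hands on the gradients themselves and apply your projection formula directly, you can attach a hook to z that rewrites dL/dz during backpropagation. Here is a minimal sketch of that approach (the tensors and loss are just illustrative stand-ins, not your actual model):

import torch

z = torch.randn(5, 2, requires_grad=True)       # stand-in for your network's output
t = torch.tensor([2.0, 2.0])                    # illustrative target

def project_to_tangent(grad):
    # g_tangent = g - (g . z_normalized) z_normalized
    z_normalized = z.detach() / z.detach().norm(dim=1, keepdim=True)
    return grad - (grad * z_normalized).sum(dim=1, keepdim=True) * z_normalized

z.register_hook(project_to_tangent)             # hook runs on dL/dz during backward()

torch.nn.MSELoss(reduction='sum')(z, t).backward()
print((z * z.grad).sum(dim=1))                  # ~ zeros, so z.grad is tangent to the sphere

Either way works; the difference is that the hook changes the gradient after the fact, while projecting onto the sphere before computing the loss keeps everything an ordinary gradient of a (modified) objective, with no extra bookkeeping.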