Here’s my problem:
I have a loss function composed of two terms. The first one I can compute directly, but the second comes from a function I would have to integrate, and that integral is (as far as I can tell) intractable. This isn’t really a problem, though, because in this case I have a closed-form expression for its gradient with respect to the model output.
So I’d like to compute a partial loss made of the tractable term only, and then, when starting backpropagation, add my known analytic gradient to the gradient of that partial loss (both taken with respect to the model output) before propagating through the rest of the model.
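In other words, the mechanism I’m after looks roughly like this toy sketch (the model, the partial loss and the analytic gradient here are made up just to show the idea; my real code is below):

import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
Z = model(x)                     # model output

partial_loss = Z.pow(2).sum()    # stand-in for the tractable term
# gradient of the tractable term with respect to the model output
grad_partial = torch.autograd.grad(partial_loss, Z, retain_graph=True)[0]
# stand-in for the gradient I know in closed form
analytic_grad = torch.ones_like(Z)

# push the combined gradient back through the model, so the parameters
# end up with the gradient of (term1 + term2) by the chain rule
model.zero_grad()
Z.backward(gradient=grad_partial + analytic_grad)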
Here’s some example code:
import torch

class FR_loss:
    """Fruchterman-Reingold style loss on the embedding X of the graph G."""

    def __init__(self, G, X, k):
        self.X = X
        self.G = G
        self.k = k
        # pairwise distances between node embeddings, keyed by the sorted edge
        self.dists = {ord_edge((u, v)): torch.norm(self.X[v] - self.X[u])
                      for u in range(len(self.G)) for v in range(u + 1, len(self.G))}
        self.disp = self.compute_disp()

    def compute_disp(self):
        # displacement D_v of each node: repulsion from every other node
        # plus attraction along the edges incident to v
        disp = dict()
        for v in self.G:
            D_v = 0
            for u in set(self.G.nodes()) - {v}:
                e_uv = (self.X[v] - self.X[u]) / self.dists[ord_edge((u, v))]
                D_v += -((self.k ** 2) / self.dists[ord_edge((u, v))]) * e_uv
            for u in self.G.neighbors(v):
                e_uv = (self.X[v] - self.X[u]) / self.dists[ord_edge((u, v))]
                D_v += ((self.dists[ord_edge((u, v))] ** 2) / self.k) * e_uv
            disp[v] = D_v
        return disp

    def compute_remaining_grad(self, t, lr):
        # analytic gradient (w.r.t. X) of the intractable term: only the nodes
        # whose displacement norm exceeds the threshold t contribute
        grad_E_v_larger = torch.zeros((len(self.G), self.X.shape[1]))
        for v in self.G:
            D_v = self.disp[v]
            if torch.norm(D_v) > t:
                grad_E_v_larger[v] += t * (D_v / torch.norm(D_v))
        return grad_E_v_larger / lr  # rescaled by the learning rate

    def compute_loss_partial(self, t, lr):
        # tractable part of the loss: only the nodes whose displacement norm
        # is below the threshold t contribute
        E = 0
        for v in self.G:
            E_v = 0
            D_v = self.disp[v]
            if torch.norm(D_v) < t:
                for u in set(self.G.nodes()) - {v}:
                    E_v += -(self.k ** 2) * torch.log(self.dists[ord_edge((u, v))])
                for u in self.G.neighbors(v):
                    E_v += (1. / (3 * self.k)) * (self.dists[ord_edge((u, v))] ** 3)
            E += E_v
        # NOTE: if no node passes the threshold, E stays a plain 0 (not a tensor)
        return E / lr  # rescaled by the learning rate
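(In case it matters: ord_edge and nx_from_torch are small helpers of mine; they are roughly equivalent to the following, modulo details.)

import networkx as nx

def ord_edge(e):
    # return the edge with its endpoints sorted, so that (u, v) and (v, u)
    # index the same entry of the dists dictionary
    u, v = e
    return (u, v) if u < v else (v, u)

def nx_from_torch(edge_index):
    # build an undirected networkx graph from a 2 x num_edges edge_index tensor
    G = nx.Graph()
    G.add_edges_from(edge_index.t().tolist())
    return G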
Then I run:
def train_emb_FR(model, data, k, t, optimizer):
    torch.set_grad_enabled(True)
    model.train()
    G = nx_from_torch(data.edge_index)    # networkx graph built from edge_index
    Z = model(data.x)                     # node embeddings
    fr_loss = FR_loss(G, Z, k)
    loss = fr_loss.compute_loss_partial(t, optimizer.defaults['lr'])
    remain_grad = fr_loss.compute_remaining_grad(t, optimizer.defaults['lr'])
    optimizer.zero_grad()
    if loss == 0:
        # no node was below the threshold: only the analytic gradient remains
        Z.grad = remain_grad
    else:
        # gradient of the tractable part of the loss w.r.t. the embedding Z,
        # combined with the analytic gradient of the intractable part
        part_gradients = torch.autograd.grad(loss, Z, retain_graph=True)[0]
        Z.grad = part_gradients + remain_grad
        loss.backward()
    optimizer.step()
    return float(loss)
But it does not work the way I want it to…
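Concretely, I suspect that assigning to Z.grad is the weak point (Z is not a leaf tensor), and that I should instead push the combined gradient through the graph myself, along the lines of:

optimizer.zero_grad()
# part_gradients: autograd gradient of the partial loss w.r.t. Z
# remain_grad:    my analytic gradient of the intractable term
torch.autograd.backward(Z, grad_tensors=part_gradients + remain_grad)
optimizer.step()

Is that the right way to inject a hand-computed gradient into the backward pass, or am I missing something?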