I have to implement a custom loss function to estimate some weights using the total variation of a set of images. The problem is that my function requires np.gradient(), and that breaks the computation graph. Does anyone know how to solve this?
Here’s what I have so far:
import numpy as np
import torch
from torch.autograd import Function

class TotalVariation(Function):
    @staticmethod
    def forward(ctx, projections, meanFF, FF, DF, x):
        # Effective flat field: weighted sum of the FF images
        FF_eff = torch.zeros((FF.shape[1], FF.shape[2]))
        for i in range(len(FF)):
            FF_eff = FF_eff + x[i] * FF[i]
        logCorProj = (projections - DF) / (meanFF + FF_eff) * torch.mean(
            torch.flatten(meanFF) + torch.flatten(FF_eff))
        # tmp = torch.mean(logCorProj)
        # return tmp
        # pytorch doesn't have a np.gradient equivalent, so I drop to
        # numpy here -- this is what breaks the computation graph
        logCorProj = logCorProj.clone().detach().numpy()
        Gx, Gy = np.gradient(logCorProj)
        Gx, Gy = torch.tensor(Gx), torch.tensor(Gy)
        mag = (Gx**2 + Gy**2)**(1/2)
        cost = torch.sum(mag)
        ctx.save_for_backward(cost)
        return cost

    @staticmethod
    def backward(ctx, grad_outputs):
        # One gradient per forward input (projections, meanFF, FF, DF, x)
        result, = ctx.saved_tensors
        return result, None, None, None, None
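
For reference, this is the direction I've been sketching: rebuilding np.gradient out of plain tensor slicing (central differences in the interior, one-sided differences at the borders), so the whole cost stays differentiable. It's untested, and the image_gradients / total_variation names are just mine, but if something like this is valid I could presumably compute the cost with ordinary torch ops and drop the custom backward entirely:

import torch

def image_gradients(img):
    # Finite differences that mimic np.gradient on a 2D tensor:
    # central differences in the interior, one-sided at the edges.
    # Built only from slicing and arithmetic, so autograd can trace it.
    gx = torch.cat([
        img[1:2, :] - img[0:1, :],         # forward diff, first row
        (img[2:, :] - img[:-2, :]) / 2,    # central diff, interior rows
        img[-1:, :] - img[-2:-1, :],       # backward diff, last row
    ], dim=0)
    gy = torch.cat([
        img[:, 1:2] - img[:, 0:1],         # forward diff, first column
        (img[:, 2:] - img[:, :-2]) / 2,    # central diff, interior columns
        img[:, -1:] - img[:, -2:-1],       # backward diff, last column
    ], dim=1)
    return gx, gy

def total_variation(img):
    # Sum of gradient magnitudes -- same cost as in my forward() above,
    # but differentiable end to end (no detach, no numpy round trip)
    gx, gy = image_gradients(img)
    return torch.sum(torch.sqrt(gx**2 + gy**2))

I also noticed that newer PyTorch releases ship torch.gradient(), which might make image_gradients redundant, but I haven't checked whether it supports autograd.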