I have been implementing Deep Hashing via Discrepancy Minimization for quite some time now, but I am stuck: the parameters of my model (a modified AlexNet) are not updating. There seems to be a problem with the gradient computation, because when I print loss.grad after calling loss.backward(), I get None. H has requires_grad=True, D_hat and delta have requires_grad=False, and lambda1 and lambda2 are scalars.
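For reference, the tensors involved look roughly like this (a simplified, self-contained sketch; the sizes and the random D_hat are only stand-ins for my actual data):

import torch

batch_size, hash_bits = 32, 48              # stand-in sizes
D_hat = torch.rand(batch_size, batch_size)  # pairwise term, requires_grad=False
delta = torch.zeros(batch_size, hash_bits)  # same shape as H, requires_grad=False
lambda1, lambda2 = 0.1, 0.1                 # plain scalars
# H itself comes out of the modified AlexNet, so it has requires_grad=True

The custom autograd Function computing the loss is: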
import torch
from torch.autograd import Function

class loss_with_H(Function):
    @staticmethod
    def forward(ctx, H, D_hat, delta, lambda1, lambda2):
        # with both regularizers switched off, delta plays no role in the loss
        if lambda1 == 0 and lambda2 == 0:
            delta = torch.zeros_like(H)
        temp = D_hat.t() + D_hat
        # save everything the backward pass needs
        ctx.save_for_backward(temp, H, delta)
        ctx.lambda1 = lambda1
        # L = tr((H^T D_hat + lambda1 * delta^T (D_hat^T + D_hat)) H + lambda2 * delta^T D_hat delta)
        L = torch.trace(
            torch.mm(torch.mm(H.t(), D_hat) + torch.mm(lambda1 * delta.t(), temp), H)
            + lambda2 * torch.mm(delta.t(), torch.mm(D_hat, delta))
        )
        return L

    @staticmethod
    def backward(ctx, grad_out):
        temp, H, delta = ctx.saved_tensors
        lambda1 = ctx.lambda1
        # only H receives a gradient; the remaining inputs are treated as constants
        grad_D_hat = grad_delta = grad_lambda1 = grad_lambda2 = None
        # dL/dH = (D_hat^T + D_hat)(H + lambda1 * delta)
        grad_H = torch.mm(temp, H + lambda1 * delta)
        return grad_H * grad_out, grad_D_hat, grad_delta, grad_lambda1, grad_lambda2
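This is roughly how I invoke the Function and run the backward pass (simplified; the small linear model and optimizer below are stand-ins for my modified AlexNet and actual training setup, reusing the tensor setup from the sketch above):

import torch.nn as nn

model = nn.Linear(4096, hash_bits)          # stand-in for the modified AlexNet
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
features = torch.randn(batch_size, 4096)    # stand-in for a batch of inputs

H = model(features)                         # requires_grad=True, but not a leaf tensor
loss = loss_with_H.apply(H, D_hat, delta, lambda1, lambda2)

optimizer.zero_grad()
loss.backward()
print(loss.grad)                            # this prints None
print(next(model.parameters()).grad)        # checking what reaches the network
optimizer.step()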
I am updating delta as follows:
def update_delta(D_hat, H, delta, lambda1, lambda2):
    # delta is updated outside of autograd, so no gradients are tracked here
    with torch.no_grad():
        temp = D_hat.t() + D_hat
        # gradient of the loss with respect to delta
        grad_delta = torch.mm(temp, lambda1 * H + lambda2 * delta)
        # sign-based update, shifted by H
        delta = -torch.sign(grad_delta) - H
    return delta
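Overall, the alternation I am running looks roughly like this (again a simplified sketch; num_epochs and loader are placeholders for my actual data pipeline, and the other names are the stand-ins from above):

num_epochs = 10                             # placeholder
for epoch in range(num_epochs):
    for features, D_hat in loader:          # loader is a placeholder for my data pipeline
        H = model(features)
        loss = loss_with_H.apply(H, D_hat, delta, lambda1, lambda2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # alternate step: refresh delta with the sign update above
        delta = update_delta(D_hat, H, delta, lambda1, lambda2)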