# Danskin's theorem: no_grad on auxiliary optimization throws off gradient of function input

Hi,

I am running into an issue with some gradients in a network I am training, and I think I have narrowed it down to a conceptual issue on my end. I have a layer `fn` which solves a convex optimization problem (Sinkhorn’s algorithm), which basically involves constructing two auxiliary variables `u, v` in a loop and then combining them together with the function input `C` to form the final answer. By Danskin’s theorem, the gradient of this layer should be obtainable using the optimal values of the auxiliary variables, without having to back-propagate through the optimization steps themselves.
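For context, the pattern I expect does seem to work on a toy scalar problem. The following is just an illustrative sketch of the envelope-theorem idea (a hand-rolled inner gradient-descent loop on a made-up objective, not my actual Sinkhorn layer):

```python
import torch

# Toy check of the envelope theorem:
# f(theta) = min_x [(x - theta)^2 + x^2] has minimizer x* = theta / 2
# and optimal value theta^2 / 2, so df/dtheta = theta.
theta = torch.tensor(1.5, dtype=torch.float64, requires_grad=True)

# Solve the inner problem without tracking gradients.
with torch.no_grad():
    x = torch.tensor(0.0, dtype=torch.float64)
    for _ in range(200):
        x -= 0.1 * (2 * (x - theta) + 2 * x)  # gradient step on the inner objective

# Re-evaluate the objective at the (detached) minimizer and differentiate
# with respect to theta only.
value = (x - theta) ** 2 + x ** 2
value.backward()
print(theta.grad)  # ~1.5, matching the analytic derivative df/dtheta = theta
```

Here treating the inner minimizer `x` as a constant gives the correct gradient of the optimal value, which is exactly what I assumed would carry over to the Sinkhorn case.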

Unfortunately, when I check the gradients with `torch.autograd.gradcheck`, it appears this does not work out in practice, despite the optimization objective converging correctly.

A minimal example is as follows:

```python
import torch

NUM_ITERS = 999  # for example

def fn_with_no_grad(C):
    # Compute u and v without gradients
    b, n, m = C.shape
    u = torch.zeros(b, n, dtype=C.dtype, device=C.device, requires_grad=False)
    v = torch.zeros(b, m, dtype=C.dtype, device=C.device, requires_grad=False)
    with torch.no_grad():
        for i in range(NUM_ITERS):
            if i % 2 == 0:
                u.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-1))
            else:
                v.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-2))
    M = -C + u.unsqueeze(-1) + v.unsqueeze(-2)
    M = torch.exp(M)
    return M

def fn(C):
    # Compute u and v with gradients
    b, n, m = C.shape
    u = torch.zeros(b, n, dtype=C.dtype, device=C.device, requires_grad=False)
    v = torch.zeros(b, m, dtype=C.dtype, device=C.device, requires_grad=False)
    for i in range(NUM_ITERS):
        if i % 2 == 0:
            u.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-1))
        else:
            v.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-2))
    M = -C + u.unsqueeze(-1) + v.unsqueeze(-2)
    M = torch.exp(M)
    return M

if __name__ == "__main__":
    C = torch.randn(2, 10, 5, device="cuda", dtype=torch.float64, requires_grad=True)

    # Check forward passes are the same
    assert torch.allclose(fn(C), fn_with_no_grad(C))

    # Check gradients: the first passes, the second fails
    print(torch.autograd.gradcheck(fn, (C,)))
    print(torch.autograd.gradcheck(fn_with_no_grad, (C,)))
```
As you can see, the `fn_with_no_grad` version computes `u` and `v` under `torch.no_grad()`, so they are effectively treated as constants during the backward pass (in accordance with Danskin's theorem). Yet `gradcheck` passes for `fn` and fails for `fn_with_no_grad`.

Could anyone help me demystify what is happening here? I am confused as to why `u` and `v` would even contribute to the gradient, given that they are initialized with `requires_grad=False`. Has anyone had success applying Danskin's theorem in conjunction with autograd?