Hi,

I am running into an issue with some gradients in a network I am training, and I think I have narrowed it down to a conceptual issue on my end. I have a layer `fn` which solves a convex optimization problem (Sinkhorn’s algorithm): it constructs two auxiliary variables `u, v` in a loop and then combines them with the function input `C` to form the final answer. By Danskin’s theorem, the gradient of this layer should be obtainable from the optimal values of the auxiliary variables, without having to back-propagate through the optimization steps themselves.
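To spell out what I expect, here is a toy scalar example (entirely made up for illustration, not my actual layer): for g(C) = max over x of C*x - x**2/2, the maximizer is x* = C and Danskin’s theorem says dg/dC = x*. So solving the inner problem under `no_grad` and then differentiating the objective with the solution held fixed should give the right gradient:

```python
import torch

# Toy sanity check of Danskin's theorem (made-up scalar problem, not the
# Sinkhorn layer): g(C) = max_x (C*x - x**2/2) = C**2/2, so dg/dC = x* = C.
C = torch.tensor(1.5, dtype=torch.float64, requires_grad=True)

# Solve the inner maximization without tracking gradients.
with torch.no_grad():
    x = torch.zeros((), dtype=torch.float64)
    for _ in range(200):
        x += 0.1 * (C - x)  # gradient ascent step on C*x - x**2/2

# Differentiate the objective with the (constant) inner solution plugged in.
g = C * x - x ** 2 / 2
g.backward()
print(torch.allclose(C.grad, C.detach()))  # prints True: the gradient is x* = C
```

This is exactly the pattern I am trying to replicate in `fn_with_no_grad` below.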

Unfortunately, when I check the gradients with `torch.autograd.gradcheck`, it appears this does not work out in practice, despite the optimization objective converging correctly.

A minimal example is as follows:

```
import torch

NUM_ITERS = 999  # for example


def fn_with_no_grad(C):
    # Compute u and v without gradients
    b, n, m = C.shape
    u = torch.zeros(b, n, dtype=C.dtype, device=C.device, requires_grad=False)
    v = torch.zeros(b, m, dtype=C.dtype, device=C.device, requires_grad=False)
    with torch.no_grad():
        for i in range(NUM_ITERS):
            if i % 2 == 0:
                u.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-1))
            else:
                v.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-2))
    M = -C + u.unsqueeze(-1) + v.unsqueeze(-2)
    M = torch.exp(M)
    return M


def fn(C):
    # Compute u and v with gradients
    b, n, m = C.shape
    u = torch.zeros(b, n, dtype=C.dtype, device=C.device, requires_grad=False)
    v = torch.zeros(b, m, dtype=C.dtype, device=C.device, requires_grad=False)
    for i in range(NUM_ITERS):
        if i % 2 == 0:
            u.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-1))
        else:
            v.sub_(torch.logsumexp(-C + u.unsqueeze(-1) + v.unsqueeze(-2), dim=-2))
    M = -C + u.unsqueeze(-1) + v.unsqueeze(-2)
    M = torch.exp(M)
    return M


if __name__ == "__main__":
    C = torch.randn(2, 10, 5, device="cuda", dtype=torch.float64, requires_grad=True)
    # Check forward passes are the same
    print(torch.allclose(fn(C), fn_with_no_grad(C)))  # prints True
    # Check gradients are as expected
    print(torch.autograd.gradcheck(fn, (C,), raise_exception=False))  # prints True
    print(
        torch.autograd.gradcheck(fn_with_no_grad, (C,), raise_exception=False)
    )  # prints False
```

As you can see, the `fn_with_no_grad` version detaches `u` and `v`, essentially treating them as constants during the backward pass (in accordance with Danskin’s theorem).

Since the outputs of the two functions match (the first check prints True), I believe the objective is in fact converging (empirically it converges in O(10) iterations on random data). However, the gradient is apparently incorrect, as judged by gradcheck.

Could anyone help me demystify what is happening here? I am confused as to why `u` and `v` would even contribute to the gradient in `fn`, given that they are initialized with `requires_grad=False`. Has anyone had success applying Danskin’s theorem in conjunction with autograd?
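For reference, here is a minimal check I ran (separate from the layer above, with shapes made up for illustration) of how the in-place updates interact with `requires_grad`, which may be relevant:

```python
import torch

# Minimal check: an in-place update on a requires_grad=False tensor, using an
# operand that does require grad, flips requires_grad on the updated tensor.
C = torch.randn(3, dtype=torch.float64, requires_grad=True)
u = torch.zeros(3, dtype=torch.float64, requires_grad=False)
print(u.requires_grad)  # prints False
u.sub_(torch.logsumexp(-C + u, dim=-1))
print(u.requires_grad)  # prints True: u is now part of the autograd graph
```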