What's the current idiomatic way to do an SGD step without using x.grad.data?

What is the current way to write SGD-like code without using grad.data, and why is .data not documented anywhere (other than the 0.4 migration guide)?

Here is some simple code that solves Ax=b using SGD with autograd. I can’t find an elegant way to implement the descent part without using grad.data. Looking at the optimizers in the PyTorch code base, they all use grad.data.

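import torch

def update1(x):
    # out-of-place step: builds a new leaf tensor, then re-enables grad on it
    with torch.no_grad():
        tmp = x - 0.0001*x.grad
    x.grad.zero_()
    return tmp.requires_grad_(True)

def update2(x):
    # in-place step through .data (the pattern used by the core optimizers)
    x.data -= 0.0001*x.grad.data
    x.grad.data.zero_()
    return x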

# given A and b, and Ax=b: find x using SGD
A = torch.randn(6, 6, requires_grad=False)
b = torch.randn(6, 1, requires_grad=False)

# choose random x and then search with SGD to find x-hat that is the closest to x in Ax=b
x = torch.randn(6, 1, requires_grad=True)

for i in range(40000):
    loss = torch.norm(A @ x - b)
    loss.backward()
    x = update1(x)
    #x = update2(x)
    if not i % 5000: print(f"{loss: >.6f}    {x}")

update2 is how the core optimizers do it, but if grad.data shouldn’t be used, what’s the idiomatic way to do that? I tried a clumsy workaround in update1 that works, but it looks terrible. And if the idiomatic way is to continue using grad.data, then why is it not documented?

P.S. I understand the potential dangers of using grad.data as explained in the 0.4 migration notes; that is not what I’m asking about.

Thank you!

I did not find any way of bypassing grad.data, but what I did was avoid resetting the grad in the update1 function:

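import torch

def update1(x):
    with torch.no_grad():
        tmpGrad = x.grad            # keep the old gradient tensor
        x = x - 0.0001*x.grad       # new tensor, which has no .grad of its own
        x.grad = tmpGrad            # re-attach the old gradient
    return x.requires_grad_(True)

def update2(x):
    x.data -= 0.0001*x.grad.data    # second step with the same gradient
    x.grad.data.zero_()
    x.grad.zero_()                  # redundant: zeroes the same storage again
    return x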

# given A and b, and Ax=b: find x using SGD
A = torch.randn(6, 6, requires_grad=False)
b = torch.randn(6, 1, requires_grad=False)

# choose random x and then search with SGD to find x-hat that is the closest to x in Ax=b
x = torch.randn(6, 1, requires_grad=True)

for i in range(10000):
    loss = torch.norm(A @ x - b)
    loss.backward()
    x = update1(x)
    x = update2(x)
    if not i % 100: print(f"{loss: >.6f}")

Hope this helps


Hi Stas,

Good to see you around! Dropping the .data and wrapping the update in no_grad works here, as far as I can see:

def perfect_update(x):
    with torch.no_grad():
        x -= 0.0001 * x.grad
    x.grad.zero_()
    return x

(I’m having reservations about calling it update when it also zeros grad and about returning x, too, but hey.)

There are two things you previously achieved by using .data:

  • We want no autograd graph to be recorded for the subtraction. We can get this without .data by wrapping the subtraction in with torch.no_grad(). This preserves the annotation that x requires gradients.
  • The other thing .data does is prevent x from being marked as modified in-place. But having x marked as modified is in fact the right thing: anything that still needs the old x will complain (rightfully, since if you still need the old x after it has been overwritten, you have a bug in your code), and everything else will happily take the new x.

The prerequisite for all this is that x is a leaf in the autograd graph, but that should be the case anyway.
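A small check of the second bullet, using a toy function y = x * x chosen only for illustration: an in-place update under torch.no_grad() bumps the tensor's version counter, so a graph that saved the old x refuses to backpropagate. Going through .data skips that bookkeeping, so backward() runs without complaint but computes the gradient from the already-modified values.

```python
import torch

# Case 1: in-place update under no_grad is recorded, so autograd catches the stale graph.
x = torch.tensor([2.0], requires_grad=True)
y = (x * x).sum()            # the multiplication saves x for the backward pass
with torch.no_grad():
    x -= 1.0                 # x is still a leaf that requires grad
caught_error = False
try:
    y.backward()
except RuntimeError:         # "...modified by an inplace operation..."
    caught_error = True

# Case 2: the same update through .data is invisible to autograd.
x2 = torch.tensor([2.0], requires_grad=True)
y2 = (x2 * x2).sum()
x2.data -= 1.0               # does not bump x2's version counter
y2.backward()                # no error, but the saved x2 now holds 1.0
stale_grad = x2.grad.item()  # 1.0 + 1.0 = 2.0 instead of the correct 2 * 2.0 = 4.0
```

The silent wrong gradient in the second case is exactly the kind of bug that the version check in the first case is designed to surface.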

Work on dropping .data is tracked at #30987. Personally, I think one could just go ahead and tick off the things from the list that can be done today, but given that I’m not the one doing it…

Best regards

Thomas


Thank you, Thomas! You too!

Oh, I see. This wasn’t working for me because I was trying the direct x = x - lr*grad approach:

    with torch.no_grad():
        x = x - 0.0001 * x.grad
    _= x.grad.zero_()

which fails with AttributeError: 'NoneType' object has no attribute 'zero_' in the x.grad.zero_() call, because the out-of-place subtraction rebinds x to a new tensor that has no .grad. So in-place subtraction is what I was missing. Later I switched to in-place updates but forgot to try them with this approach, and the other variations I tried kept failing.
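The difference can be seen in isolation. In the snippet above the name x is rebound; here the out-of-place result is stored as x_new so both tensors can be inspected side by side. The step size is an arbitrary value for the demo.

```python
import torch

lr = 0.1  # arbitrary step size for the demo
x = torch.ones(3, requires_grad=True)
(x * 2).sum().backward()  # x.grad is now tensor([2., 2., 2.])

# Out-of-place under no_grad: x_new is a brand-new tensor created outside autograd.
with torch.no_grad():
    x_new = x - lr * x.grad
print(x_new.requires_grad, x_new.grad)  # False None, so x_new.grad.zero_() would raise AttributeError

# In-place under no_grad: the same leaf tensor is updated and keeps its .grad.
x_before = x
with torch.no_grad():
    x -= lr * x.grad
print(x is x_before, x.requires_grad)   # True True
x.grad.zero_()  # works, because x.grad still exists
```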

So yes, your perfect_update does the trick!

In-place modification of x is a must then.

Thank you, Thomas.

(I’m having reservations about calling it update when it also zeros grad and about returning x, too, but hey.)

Well, that was just an abstraction I added to make my question more readable with the different x update versions. You’re right, it’s not an update.
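On the naming point, one option is an in-place step function that returns nothing, following PyTorch’s trailing-underscore convention for in-place operations. The name sgd_step_, the lr value, and the choice of resetting grad to None instead of zeroing it are assumptions for illustration only:

```python
import torch

def sgd_step_(params, lr):
    """Hypothetical helper: plain SGD step applied in place; returns None."""
    with torch.no_grad():  # keep the update out of the autograd graph
        for p in params:
            if p.grad is None:
                continue
            p.add_(p.grad, alpha=-lr)  # p <- p - lr * grad, in place
            p.grad = None  # drop the gradient so the next backward() starts fresh

torch.manual_seed(0)
A = torch.randn(6, 6)
b = torch.randn(6, 1)
x = torch.randn(6, 1, requires_grad=True)

initial_loss = torch.norm(A @ x - b).item()
for _ in range(2000):
    loss = torch.norm(A @ x - b)
    loss.backward()
    sgd_step_([x], lr=1e-3)
final_loss = torch.norm(A @ x - b).item()
```

Since the helper never rebinds x, the caller keeps the same leaf tensor throughout and there is nothing to return.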

The full, clean solution is now:

import torch

# given A and b, and Ax=b: find x using SGD
A = torch.randn(6, 6, requires_grad=False)
b = torch.randn(6, 1, requires_grad=False)

# choose random x and then search with SGD to find x-hat that is the closest to x in Ax=b
x = torch.randn(6, 1, requires_grad=True)

for i in range(40000):
    loss = torch.norm(A @ x - b)
    loss.backward()
    with torch.no_grad():
        x -= 0.0001 * x.grad  # in-place, so x stays the same leaf and keeps its .grad
    x.grad.zero_()
    if not i % 5000: print(f"{loss: >.6f}    {x.T}")
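For comparison, the built-in torch.optim.SGD with its default settings (no momentum, no weight decay) applies the same update, x <- x - lr * x.grad, inside the optimizer. A sketch that runs both loops from the same starting point and checks that they agree; the seed, lr, and step count are arbitrary choices:

```python
import torch

torch.manual_seed(0)
A = torch.randn(6, 6)
b = torch.randn(6, 1)
x0 = torch.randn(6, 1)
lr, steps = 1e-4, 500

# Manual loop from the thread: in-place step under no_grad, then zero the grad.
x_manual = x0.clone().requires_grad_(True)
for _ in range(steps):
    loss = torch.norm(A @ x_manual - b)
    loss.backward()
    with torch.no_grad():
        x_manual -= lr * x_manual.grad
    x_manual.grad.zero_()

# Same optimization with the built-in optimizer.
x_opt = x0.clone().requires_grad_(True)
opt = torch.optim.SGD([x_opt], lr=lr)
for _ in range(steps):
    opt.zero_grad()
    loss = torch.norm(A @ x_opt - b)
    loss.backward()
    opt.step()

# Small float rounding differences are possible, hence allclose rather than equal.
agree = torch.allclose(x_manual, x_opt, atol=1e-6)
```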

Excellent!

That thread answers the other part of my question, namely why .data is not documented, and describes the effort to remove uses of .data from the PyTorch code.

And this comment in that thread summarizes the situations where it’s still used.

Thank you, Thomas.