How does one make sure that the parameters are updated manually in PyTorch using modules?

How does one make sure that the updates for parameters indeed happen when one subclasses nn modules (or uses torch.nn.Sequential)? I tried making my own class but I was never able to update the parameters for some reason. The SGD code for the nn module is (https://github.com/brando90/simple_regression/blob/master/minimum_example.py):

mdl_sgd = torch.nn.Sequential( torch.nn.Linear(D_sgd,1,bias=False) )
...
for i in range(nb_iter):
    # Forward pass: compute predicted Y using operations on Variables
    batch_xs, batch_ys = get_batch2(X,Y,M,dtype) # [M, D], [M, 1]
    ## FORWARD PASS
    y_pred = mdl_sgd(batch_xs)
    ## LOSS
    loss = (1/N)*(y_pred - batch_ys).pow(2).sum()
    ## Manually zero the gradients after updating weights
    mdl_sgd.zero_grad()
    ## BACKWARD PASS
    loss.backward() # Use autograd to compute the backward pass. Now w will have gradients
    ## SGD update
    for W in mdl_sgd.parameters():
        #print(W.grad.data)
        W.data = W.data - eta*W.grad.data

which does not work for some reason unknown to me, though when I create the variables explicitly, the updates do happen (https://github.com/brando90/simple_regression/blob/master/direct_example.py):

X = poly_kernel_matrix(x_true,Degree_mdl) # maps to the feature space of the model
X = Variable(torch.FloatTensor(X).type(dtype), requires_grad=False)
Y = Variable(torch.FloatTensor(Y).type(dtype), requires_grad=False)
w_init=torch.randn(D_sgd,1).type(dtype)
W = Variable( w_init, requires_grad=True)
...
for i in range(nb_iter):
        # Forward pass: compute predicted Y using operations on Variables
        batch_xs, batch_ys = get_batch2(X,Y,M,dtype) # [M, D], [M, 1]
        ## FORWARD PASS
        #y_pred = mdl_sgd.forward(X)
        y_pred = batch_xs.mm(W)
        ## LOSS
        loss = (1/N)*(y_pred - batch_ys).pow(2).sum()
        ## BACKWARD PASS
        loss.backward() # Use autograd to compute the backward pass. Now w will have gradients
        ## SGD update
        W.data = W.data - eta*W.grad.data
        ## Manually zero the gradients after updating weights
        #mdl_sgd.zero_grad()
        W.grad.data.zero_()

I am not 100% sure what I am doing wrong, but if you know, feel free to tell me!

I also made a more detailed SO question since I’ve received very good responses from SO in the past (https://stackoverflow.com/questions/45626848/how-does-one-make-sure-that-the-parameters-are-update-manually-in-pytorch-using).


Generally you shouldn't be reassigning the .data of Variables, though it should work, I think. All built-in optimizers do the update in-place (W.data.sub_(lr * W.grad.data)).
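
For example, a manual in-place step looks roughly like this (a minimal self-contained sketch; the toy model, data and eta value are just placeholders):

import torch
from torch.autograd import Variable

eta = 0.01  # learning rate (placeholder value)
mdl = torch.nn.Sequential(torch.nn.Linear(3, 1, bias=False))
x = Variable(torch.randn(8, 3))
y = Variable(torch.randn(8, 1))

loss = (mdl(x) - y).pow(2).mean()
mdl.zero_grad()
loss.backward()

for W in mdl.parameters():
    W.data.sub_(eta * W.grad.data)  # in-place: mutates the tensors the model itself holds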


then how should I be updating the variables if I want to do it manually? (note I chose SGD as an example since I knew what should happen, but I really wanted to play around with different update rules)


Note: I now tried the update rule you suggested, and it didn't seem to work:

W.data.sub_(eta*W.grad.data)

I wonder if there is something really small and weird that I am doing that makes code that seems like it should work not work…

note that I am re-assigning .data because the tutorials do it:

http://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-nn

for param in model.parameters():
    param.data -= learning_rate * param.grad.data

and

w1.data -= learning_rate * w1.grad.data
w2.data -= learning_rate * w2.grad.data

Also, why should we not be re-assigning .data?


Note that -= stands for in-place subtraction, so you are not re-assigning it.

does in-place subtraction mean w.sub_(X), or what does it mean exactly?

Have a look at my answer in "What is the recommended way to re-assign/update values in a variable (or tensor)?"

Whenever a function in PyTorch ends with an underscore, it means that the function is in-place.
So x.sub_(w) is the same as x -= w for tensors x and w.
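
A tiny illustration of the equivalence (arbitrary values, just to show that both forms modify the existing tensor):

import torch

x = torch.ones(3) * 5
w = torch.ones(3) * 2

y = x.clone()
y.sub_(w)   # trailing underscore: in-place subtraction on y
x -= w      # the same in-place subtraction, written with the operator

print(torch.equal(x, y))  # True: both are now [3., 3., 3.]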


so in-place means the usual in-place, as in ordinary algorithms? e.g. [2,1] -> [1,2] in place means to me that the 1 and 2 swap places without creating copies of the objects.

in-place here means that there is no extra memory allocation.
One more example, when you do

a = a + 1

you allocate a new tensor whose value is a + 1 and assign it to the name a, overwriting the previous reference to the old tensor. Still, the memory for a + 1 had to be freshly allocated.
But when you do

a += 1

no extra memory is allocated, and the addition is performed directly on the original elements of a.
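
One way to see this is to look at where the tensor's data lives, e.g. with data_ptr() (a small sketch; the exact addresses don't matter):

import torch

a = torch.zeros(4)
ptr = a.data_ptr()          # address of the underlying storage

a = a + 1                   # out-of-place: allocates a fresh tensor and rebinds the name a
print(a.data_ptr() == ptr)  # False: new storage was allocated

a = torch.zeros(4)
ptr = a.data_ptr()

a += 1                      # in-place: writes into the existing storage
print(a.data_ptr() == ptr)  # True: same storage, no new allocation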


@fma makes sense. What intrigues me right now is why

W = W - eta*W.grad

does not update my parameters. I understand the issue you mentioned about new variable allocation, etc., but despite that, gradient descent should still work (just memory-inefficiently)… this is puzzling me.


The reason it doesn't update your parameters is simple: you have references to your parameters elsewhere, and you are overwriting the local name that is supposed to reference your parameter.
Simple example:

a = Variable(torch.rand(2))  # for example, created in a Module

b = a  # when you get the parameter

# now perform an operation on b
b = b + 1

print(a)
print(b)  # they differ!
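
For contrast, a sketch of the in-place counterpart of the same example, where the change is visible through both names:

import torch
from torch.autograd import Variable

a = Variable(torch.rand(2))  # as above; requires_grad is False here

b = a   # same object, two names

b += 1  # in-place: writes into the shared storage

print(a)
print(b)  # identical: a sees the change too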

Thanks that makes sense conceptually.

Also, now I actually figured out what was wrong with some other code I was talking about. If I change the update rule to:

W = W - eta*W.grad

it doesn’t work. The reason is that the line above is actually nested in a loop that fetches parameters from my Sequential model, so doing this:

#1st
for W in mdl.parameters():
    W = W - eta*W.grad

vs

#2nd
for W in mdl.parameters():
    W.data = W.data - eta*W.grad.data

even if both are conceptually wrong, they are semantically very different. The first one rebinds a temporary name to a new Variable; since the original Variable held in mdl is never touched, the model looks as if it never trains. The second one actually mutates the Variables inside mdl, so the model does “work” and gets trained.
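
To convince myself, here is a quick self-contained sketch of the two variants side by side (toy model; the names are just illustrative):

import torch

mdl = torch.nn.Sequential(torch.nn.Linear(3, 1, bias=False))
before = [p.data.clone() for p in mdl.parameters()]

# 1st: rebinds the loop name W; the parameter inside mdl is untouched
for W in mdl.parameters():
    W = W - 0.1  # new tensor, bound only to the local name W

print(all(torch.equal(b, p.data) for b, p in zip(before, mdl.parameters())))  # True: nothing changed

# 2nd: mutates the tensor that mdl itself holds
for W in mdl.parameters():
    W.data = W.data - 0.1

print(all(torch.equal(b, p.data) for b, p in zip(before, mdl.parameters())))  # False: the parameters moved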

I know neither is the way PyTorch is meant to work (it seems), but at least I understand the behaviour I am seeing now as I change between these lines:

    for W in mdl_sgd.parameters():
        #print(W.grad.data)
        W = W - eta*W.grad
        #W.data.sub_(eta*W.grad.data)
        #W.data = W.data - eta*W.grad.data

from https://stackoverflow.com/questions/45626848/how-does-one-make-sure-that-the-parameters-are-update-manually-in-pytorch-using#comment78249041_45626848


oh wow, you beat me by like 4 minutes. Yeah, I figured that out. Essentially, since b = b + 1 creates a new variable with a new Python id, it doesn’t mean that a changes (in your example we are assuming a is bound inside some class or module). Thanks so much for your patience! :smiley: