How does one make sure that parameter updates actually happen when one subclasses nn.Module (or uses torch.nn.Sequential)?
I tried making my own class, but for some reason I was never able to update the parameters. The SGD code for the nn module version is (https://github.com/brando90/simple_regression/blob/master/minimum_example.py):
```python
import torch

mdl_sgd = torch.nn.Sequential(torch.nn.Linear(D_sgd, 1, bias=False))
...
for i in range(nb_iter):
    # Sample a mini-batch
    batch_xs, batch_ys = get_batch2(X, Y, M, dtype)  # [M, D], [M, 1]
    ## FORWARD PASS: feed the mini-batch so y_pred lines up with batch_ys
    y_pred = mdl_sgd(batch_xs)
    ## LOSS
    loss = (1/N)*(y_pred - batch_ys).pow(2).sum()
    ## Manually zero the gradients before the backward pass
    mdl_sgd.zero_grad()
    ## BACKWARD PASS
    loss.backward()  # autograd computes gradients for every parameter of mdl_sgd
    ## SGD update
    for W in mdl_sgd.parameters():
        #print(W.grad.data)
        W.data = W.data - eta*W.grad.data
```
This does not work, for reasons unknown to me. However, when I create the Variables explicitly, the updates do happen (https://github.com/brando90/simple_regression/blob/master/direct_example.py):
```python
import torch
from torch.autograd import Variable

X = poly_kernel_matrix(x_true, Degree_mdl)  # maps to the feature space of the model
X = Variable(torch.FloatTensor(X).type(dtype), requires_grad=False)
Y = Variable(torch.FloatTensor(Y).type(dtype), requires_grad=False)
w_init = torch.randn(D_sgd, 1).type(dtype)
W = Variable(w_init, requires_grad=True)
...
for i in range(nb_iter):
    # Forward pass: compute predicted Y using operations on Variables
    batch_xs, batch_ys = get_batch2(X, Y, M, dtype)  # [M, D], [M, 1]
    ## FORWARD PASS
    #y_pred = mdl_sgd.forward(X)
    y_pred = batch_xs.mm(W)
    ## LOSS
    loss = (1/N)*(y_pred - batch_ys).pow(2).sum()
    ## BACKWARD PASS
    loss.backward()  # Use autograd to compute the backward pass. Now W will have gradients
    ## SGD update
    W.data = W.data - eta*W.grad.data
    ## Manually zero the gradients after updating weights
    #mdl_sgd.zero_grad()
    W.grad.data.zero_()
```
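For comparison, here is a minimal sketch of the same manual SGD loop applied to an `nn.Sequential`, assuming a recent PyTorch (1.x or later, where `Variable` is merged into `Tensor`). The data, sizes, learning rate, and the random-index batch sampler standing in for `get_batch2` are all hypothetical. The two points it illustrates are: call the model on the mini-batch `batch_xs` (not the full `X`) so that `y_pred` and `batch_ys` have the same `[M, 1]` shape, and update each parameter in place under `torch.no_grad()`.

```python
import torch

torch.manual_seed(0)

# Hypothetical noiseless regression data with a known true weight.
N, D, M = 100, 3, 10   # M is the mini-batch size
eta = 0.1              # assumed learning rate
X = torch.randn(N, D)
w_true = torch.tensor([[1.0], [-2.0], [0.5]])
Y = X.mm(w_true)

mdl_sgd = torch.nn.Sequential(torch.nn.Linear(D, 1, bias=False))
w_before = mdl_sgd[0].weight.detach().clone()  # shape [1, D]

for i in range(500):
    # Random mini-batch (hypothetical stand-in for get_batch2).
    idx = torch.randint(0, N, (M,))
    batch_xs, batch_ys = X[idx], Y[idx]
    # Call the module on the batch so y_pred and batch_ys are both [M, 1].
    y_pred = mdl_sgd(batch_xs)
    loss = ((y_pred - batch_ys) ** 2).mean()
    mdl_sgd.zero_grad()
    loss.backward()
    # Manual SGD step: modify each parameter in place, outside autograd.
    with torch.no_grad():
        for W in mdl_sgd.parameters():
            W -= eta * W.grad

w_after = mdl_sgd[0].weight.detach().clone()
print(w_after)  # should be close to w_true.T
```

Note that `nn.Linear` stores its weight as `[out_features, in_features]`, i.e. `[1, D]`, which is the transpose of the `[D, 1]` layout used for `W` in the explicit version.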
I am not sure what I am doing wrong, so if you know, please tell me!
I also posted a more detailed question on Stack Overflow, since I have received very good answers there in the past (https://stackoverflow.com/questions/45626848/how-does-one-make-sure-that-the-parameters-are-update-manually-in-pytorch-using).
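When subclassing `nn.Module`, a common reason updates seem to "not happen" is that the weight is stored as a plain tensor attribute rather than an `nn.Parameter`; only registered parameters appear in `parameters()` and are touched by `zero_grad()` or an optimizer. The sketch below assumes a recent PyTorch; the class name `LinearModel`, the attribute `not_a_param`, and the data are hypothetical.

```python
import torch
import torch.nn as nn


class LinearModel(nn.Module):
    """Hypothetical subclass computing y = x @ W."""

    def __init__(self, D):
        super().__init__()
        # nn.Parameter registers W, so it shows up in self.parameters().
        self.W = nn.Parameter(torch.randn(D, 1))
        # A plain tensor attribute is NOT registered: parameters(),
        # zero_grad(), and optimizers never see it.
        self.not_a_param = torch.randn(D, 1, requires_grad=True)

    def forward(self, x):
        return x.mm(self.W)


torch.manual_seed(0)
model = LinearModel(3)
names = [name for name, _ in model.named_parameters()]
print(names)  # ['W']

x = torch.randn(8, 3)
y = torch.randn(8, 1)
W_before = model.W.detach().clone()

loss = ((model(x) - y) ** 2).mean()
model.zero_grad()
loss.backward()
# Manual SGD step over the registered parameters only.
with torch.no_grad():
    for p in model.parameters():
        p -= 0.1 * p.grad

W_after = model.W.detach().clone()
print(torch.equal(W_before, W_after))  # False, since W was updated
```

Submodules assigned as attributes (for example an `nn.Linear` or an `nn.Sequential`) are registered the same way, so their weights are also included in `model.parameters()`.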