Pre-trained model parameters do not update on custom loss

I am trying to extract the trainable parameters of a pre-trained model and then train them over a custom loss function. However, the gradients in this case are None, and hence, despite calling opt.step(), there is no update to model.parameters().


lstTxtOnly = 'text dataset'
model = torch.load('infersent.allnli.pickle')

theta_0 = []
for p in model.parameters():
    theta_0.append(p.data)

max_epochs = 5
learning_rate = 1e-2

for epoch in range(max_epochs+1):

    running_loss = 0.0
    R = torch.nn.functional.softmax(pre_R)
    z_pred = torch.max(R, 1)[1]
    
    X = model.encode(lstTxtOnly, bsize=128, tokenize=False, verbose=False)
    
    # Losses
    custom_loss = 'normal kmeans loss'
    frob_norm_squared = 0
    for (p1,p2) in zip(list(model.parameters()), theta_0):
        frob_norm_squared += torch.sum(torch.abs(p1.data - p2) **2)
    regularizer = frob_norm_squared

    # Optimizer
    opt_F= torch.optim.Adam(model.parameters(), lr = learning_rate)
    
    # zero the parameter gradients
    opt_F.zero_grad()   
    
    loss = custom_loss  + regularizer  
    loss.backward(retain_graph=True)

    for param in model.parameters():
        print(param.grad)
    
    opt_F.step()

I have already checked some other posts where one tries to update the model parameters manually, but in my case the gradient itself is None.

If I remove the retain_graph=True argument, the code breaks. Can you please advise how to debug this issue?

Two notes for now (these probably won’t fix your problem, but they’re worth a shot):

  1. You should probably move the initialization of the optimizer outside of your training loop, because optimizers can store state between steps.
  2. If loss.backward() leaves the gradients as None, I can think of two causes:
     • model.parameters() aren’t leaf nodes in the computation graph (leaf Variables are those that don’t come from operations on other Variables).
     • The loss doesn’t depend on model.parameters().
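The two notes above can be checked with a small experiment. This is a minimal sketch written against current PyTorch (0.4+), where Variable and Tensor are merged; `nn.Linear` is a hypothetical stand-in for the pre-trained model, and `other` is an unrelated tensor used only to build a loss that ignores the parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for the pre-trained model
x = torch.randn(8, 4)

# Note 1: build the optimizer ONCE, before any training loop, so that its
# per-parameter state (e.g. Adam's running averages) survives between steps.
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

# Parameters should be leaf tensors that require grad.
for name, p in model.named_parameters():
    print(name, "is_leaf:", p.is_leaf, "requires_grad:", p.requires_grad)

# Note 2, second bullet: a loss that never touches model.parameters().
other = torch.randn(3, requires_grad=True)
(other ** 2).sum().backward()
grads_unrelated = [p.grad for p in model.parameters()]  # [None, None]

# A loss computed from the model's output does reach the parameters.
opt.zero_grad()
model(x).pow(2).mean().backward()
grads_related = [p.grad for p in model.parameters()]    # two tensors
opt.step()
```

If every parameter reports `is_leaf: True` and `requires_grad: True` but the gradients are still None, the second bullet (the loss not depending on the parameters) is the likelier culprit.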

The following code doesn’t actually store a copy of the parameters as they are before the training loop begins; it just stores references to the underlying tensors.

for p in model.parameters():
    theta_0.append(p.data)

You would need to do theta_0.append(p.data.clone()) to store a copy of the parameter data.
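The difference between a reference and a copy can be seen directly. This is a sketch assuming current PyTorch, where `p.detach().clone()` is the usual equivalent of `p.data.clone()`; the model and the in-place `add_(1.0)` are hypothetical, standing in for an optimizer update.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # hypothetical stand-in model

theta_ref = [p.data for p in model.parameters()]               # references only
theta_copy = [p.detach().clone() for p in model.parameters()]  # real snapshots

# Simulate a training update applied in place to the parameters.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# The references moved together with the parameters; the clones did not.
ref_diff = sum(((p - r) ** 2).sum().item()
               for p, r in zip(model.parameters(), theta_ref))
copy_diff = sum(((p - c) ** 2).sum().item()
                for p, c in zip(model.parameters(), theta_copy))
print(ref_diff, copy_diff)
```

`ref_diff` is exactly zero because `p` and `r` share storage, while `copy_diff` reflects the update (roughly 1.0 per parameter element).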

Therefore frob_norm_squared will equal zero…

for (p1,p2) in zip(list(model.parameters()), theta_0):
        frob_norm_squared += torch.sum(torch.abs(p1.data - p2) **2)

because p2 is a reference to the same tensor referenced by p1.data.

Besides, the calculation of frob_norm_squared only uses Tensors (p1.data and p2), never Variables, so it can’t be backpropagated.

So, your regularizer does nothing, and I can only suppose that the ‘normal kmeans loss’ contains some other errors that prevent backpropagation from occurring correctly.
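Putting both fixes together, a regularizer that autograd can see compares the live parameters (not `.data`) against cloned snapshots. This is a sketch under assumptions: current PyTorch, a hypothetical `nn.Linear` model and task loss, and plain SGD. `torch.abs(...)` is dropped because squaring a real number already makes it non-negative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for the pre-trained model
theta_0 = [p.detach().clone() for p in model.parameters()]  # frozen snapshot
opt = torch.optim.SGD(model.parameters(), lr=0.1)

def frob_regularizer(model, theta_0):
    # Uses p itself (not p.data), so autograd can differentiate through it.
    return sum(((p - p0) ** 2).sum() for p, p0 in zip(model.parameters(), theta_0))

x = torch.randn(16, 4)
reg_before = frob_regularizer(model, theta_0)
task_loss = model(x).pow(2).mean()          # hypothetical task loss
loss = task_loss + reg_before

opt.zero_grad()
loss.backward()
opt.step()

reg_after = frob_regularizer(model, theta_0).item()
print(reg_before.item(), reg_after)
```

At the very first step the regularizer is zero and so is its gradient (2 * (p - p0) = 0), but the gradient is a real tensor rather than None. Once the task loss has moved the parameters, the regularizer starts pulling them back toward theta_0.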


Firstly, thank you so much @richard and @jpeg729 for the apt suggestions and explanations. I really appreciate your help in resolving this issue. You were right: I moved the optimizer outside the loop, and the regularizer does in fact do nothing!

I think the loss stays constant because of the normal k-means loss component that I calculate. The intended setup is that the pre-trained model produces some embeddings (X), which are then converted to a Variable like this:

lstTxtOnly = 'text dataset'
R_init_numpy = 'some initial values'
R_init_means =  'some other initial values'
pre_R = Variable(torch.from_numpy(R_init_numpy + 1e-8).type(dtype), requires_grad=True)
U = Variable(torch.from_numpy(R_init_means).type(dtype), requires_grad=False)

F = torch.stack(Variable(torch.from_numpy(X).type(dtype), requires_grad=True))

where,
X = model.encode(lstTxtOnly, bsize=128, tokenize=False, verbose=False)

Then the loss is computed over this Variable F as follows:

distances = torch.sum(((F.unsqueeze(1) - U) ** 2), dim=2)
custom_loss = torch.sum(R * distances) / num_samples
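The two lines above compute the soft k-means objective (1/N) Σ_i Σ_k R[i, k] · ||F[i] − U[k]||² through broadcasting. The sketch below checks that reading against an explicit double loop; the sizes and random tensors are hypothetical, and it assumes current PyTorch.

```python
import torch

torch.manual_seed(0)
num_samples, num_clusters, dim = 5, 3, 4          # hypothetical sizes
F = torch.randn(num_samples, dim)                 # embeddings, one row per sample
U = torch.randn(num_clusters, dim)                # cluster means
R = torch.softmax(torch.randn(num_samples, num_clusters), dim=1)  # soft assignments

# (N, 1, D) - (K, D) broadcasts to (N, K, D); summing over dim=2 gives
# the squared Euclidean distance from every sample to every mean.
distances = torch.sum((F.unsqueeze(1) - U) ** 2, dim=2)   # shape (N, K)
custom_loss = torch.sum(R * distances) / num_samples

# The same quantity with explicit loops, as a sanity check.
loop_loss = sum(
    R[i, k] * ((F[i] - U[k]) ** 2).sum()
    for i in range(num_samples) for k in range(num_clusters)
) / num_samples
print(custom_loss.item(), loop_loss.item())
```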

But since the model parameters never update, the embedding X doesn’t change either. F, which was constructed from the NumPy array X, won’t backpropagate into the model.
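The effect of the NumPy round trip can be reproduced on a toy model. This is a sketch assuming current PyTorch; `encoder` is a hypothetical stand-in for the sentence encoder. Re-wrapping the NumPy array produces a fresh leaf tensor, so gradients stop at F and never reach the encoder's parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(4, 3)   # hypothetical stand-in for the sentence encoder
x = torch.randn(6, 4)

# Run the model, then leave PyTorch: .numpy() needs a tensor without history.
X_numpy = encoder(x).detach().numpy()

# Re-wrapping gives a brand-new leaf tensor with no link to the encoder.
F = torch.from_numpy(X_numpy).requires_grad_(True)
loss = (F ** 2).sum()
loss.backward()

grad_reaches_F = F.grad is not None
grad_reaches_encoder = [p.grad is not None for p in encoder.parameters()]
print(grad_reaches_F, grad_reaches_encoder)   # True [False, False]
```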

How can I formulate this setup so that the pre-trained model parameters are updated to produce a new embedding, which is then used to compute the custom loss, repeating this for a fixed number of epochs?

I do not understand the need for the line

F = torch.stack(Variable(torch.from_numpy(X).type(dtype), requires_grad=True))

If I understand your code correctly, X is the model output, so it must already be a Variable. When you rewrap it in a new Variable, you cut off the gradient flow. Besides, if you give torch.stack only one argument, it just passes that argument through untouched.

If you replace the above line with

F = X

then the contents of F will be as before, but the gradient flow will no longer be cut off.
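When the encoder does return a differentiable tensor, the whole loop asked about above can stay inside autograd. This is a sketch under assumptions: current PyTorch, a hypothetical `nn.Linear` encoder instead of InferSent, random inputs, fixed means `U`, and one Adam optimizer over both the encoder and `pre_R`. Because everything is recomputed each epoch, `retain_graph=True` is not needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_samples, in_dim, emb_dim, num_clusters = 32, 10, 8, 3   # hypothetical sizes
encoder = nn.Linear(in_dim, emb_dim)     # stand-in for a pre-trained encoder
inputs = torch.randn(num_samples, in_dim)
U = torch.randn(num_clusters, emb_dim)   # fixed cluster means (no grad)
pre_R = torch.randn(num_samples, num_clusters, requires_grad=True)

theta_0 = [p.detach().clone() for p in encoder.parameters()]
# One optimizer, created once, covering both the encoder and pre_R.
opt = torch.optim.Adam(list(encoder.parameters()) + [pre_R], lr=1e-2)

losses = []
for epoch in range(5):
    opt.zero_grad()
    F = encoder(inputs)                  # F = X: no re-wrapping, graph intact
    R = torch.softmax(pre_R, dim=1)
    distances = torch.sum((F.unsqueeze(1) - U) ** 2, dim=2)
    custom_loss = torch.sum(R * distances) / num_samples
    regularizer = sum(((p - p0) ** 2).sum()
                      for p, p0 in zip(encoder.parameters(), theta_0))
    loss = custom_loss + regularizer
    loss.backward()                      # graph is rebuilt every epoch
    opt.step()
    losses.append(loss.item())
print(losses)
```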


Sadly, it isn’t. I am using the InferSent model from FAIR.

X = model.encode(sentences, tokenize=True)
This outputs a NumPy array of n vectors of dimension 4096 (the dimension of the sentence embeddings).
Hence, I need to wrap it in a Variable.

Then I see no way in which you can update the parameters using PyTorch.


Ah! That’s disheartening. I will try to think of some other work-around. Nonetheless, thank you for looking into this problem and for all your help and suggestions.

I pulled out the code from the encode function, which calls model.forward(). I think this can be used to get a Variable and backpropagate the error to update the parameters.

sentences, lengths, idx_sort = model.prepare_samples(lstSentences, bsize=128, tokenize=True, verbose=True)
batch = Variable(model.get_batch(sentences))

In the encode and visualize functions of the original model file, they always use volatile=True when creating the batch Variable. In this case, however, we should not set it to True, because we need gradients to flow back to the parameters. But I get an out-of-memory error as soon as I do this:

if model.is_cuda():
    batch = batch.cuda()
embeddingsTxtVar = model.forward((batch, lengths))


RuntimeError                              Traceback (most recent call last)
<ipython-input-13-76f33c27e0af> in <module>()
      1 if model.is_cuda():
      2     batch = batch.cuda()
----> 3 embeddingsTxtVar = model.forward((batch, lengths))

~/notebooks/code/InferSent/encoder/models.py in forward(self, sent_tuple)
     51         # Handling padding in Recurrent Networks
     52         sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len)
---> 53         sent_output = self.enc_lstm(sent_packed)[0]  # seqlen x batch x 2*nhid
     54         sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]
     55 

/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    222         for hook in self._forward_pre_hooks.values():
    223             hook(self, input)
--> 224         result = self.forward(*input, **kwargs)
    225         for hook in self._forward_hooks.values():
    226             hook_result = hook(self, input, result)

/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    160             flat_weight=flat_weight
    161         )
--> 162         output, hidden = func(input, self.all_weights, hx)
    163         if is_packed:
    164             output = PackedSequence(output, batch_sizes)

/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/_functions/rnn.py in forward(input, *fargs, **fkwargs)
    349         else:
    350             func = AutogradRNN(*args, **kwargs)
--> 351         return func(input, *fargs, **fkwargs)
    352 
    353     return forward

/anaconda/envs/py35/lib/python3.5/site-packages/torch/autograd/function.py in _do_forward(self, *input)
    282         self._nested_input = input
    283         flat_input = tuple(_iter_variables(input))
--> 284         flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
    285         nested_output = self._nested_output
    286         nested_variables = _unflatten(flat_output, self._nested_output)

/anaconda/envs/py35/lib/python3.5/site-packages/torch/autograd/function.py in forward(self, *args)
    304     def forward(self, *args):
    305         nested_tensors = _map_variable_tensor(self._nested_input)
--> 306         result = self.forward_extended(*nested_tensors)
    307         del self._nested_input
    308         self._nested_output = result

/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/_functions/rnn.py in forward_extended(self, input, weight, hx)
    291             hy = tuple(h.new() for h in hx)
    292 
--> 293         cudnn.rnn.forward(self, input, hx, weight, output, hy)
    294 
    295         self.save_for_backward(input, hx, weight, output)

/anaconda/envs/py35/lib/python3.5/site-packages/torch/backends/cudnn/rnn.py in forward(fn, input, hx, weight, output, hy)
    289                 ctypes.byref(reserve_size)
    290             ))
--> 291             fn.reserve = torch.cuda.ByteTensor(reserve_size.value)
    292 
    293             check_error(lib.cudnnRNNForwardTraining(

RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1503963423183/work/torch/lib/THC/generic/THCStorage.cu:66

I get the same error when I explicitly pass volatile=False, requires_grad=True, or both as additional arguments to
batch = Variable(model.get_batch(sentences), 'either or both additional parameter as above')

Any idea what I may be doing wrong here? Please excuse me, I am very new to PyTorch.
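As an assumption rather than a confirmed diagnosis of this error: keeping the autograd graph of a recurrent encoder alive for a very large batch costs much more memory than a volatile forward pass, and the bsize argument suggests encode() normally works through the sentences in batches. The sketch below shows the general pattern of running forward and backward one mini-batch at a time, so only one batch's graph exists at once. `ToyBiLSTMMaxEncoder`, the random "word vectors", and all sizes are hypothetical; padding and `pack_padded_sequence` are omitted, and this is not InferSent's own API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyBiLSTMMaxEncoder(nn.Module):
    """Hypothetical stand-in: a BiLSTM followed by max-pooling over time."""

    def __init__(self, word_dim, hidden):
        super().__init__()
        self.lstm = nn.LSTM(word_dim, hidden, bidirectional=True)

    def forward(self, batch):            # batch: (seq_len, batch_size, word_dim)
        out, _ = self.lstm(batch)        # (seq_len, batch_size, 2 * hidden)
        return out.max(dim=0)[0]         # (batch_size, 2 * hidden)

# Hypothetical sizes and random "word vectors" in place of real sentences.
num_samples, seq_len, word_dim, hidden, num_clusters, bsize = 40, 7, 6, 5, 3, 8
data = torch.randn(seq_len, num_samples, word_dim)
encoder = ToyBiLSTMMaxEncoder(word_dim, hidden)
U = torch.randn(num_clusters, 2 * hidden)                 # fixed cluster means
pre_R = torch.randn(num_samples, num_clusters, requires_grad=True)
# Plain SGD, so rows of pre_R outside the current batch (zero grad) stay put.
opt = torch.optim.SGD(list(encoder.parameters()) + [pre_R], lr=0.1)

for epoch in range(2):
    for start in range(0, num_samples, bsize):
        opt.zero_grad()
        # Only this slice of sentences enters the graph.
        batch = data[:, start:start + bsize].contiguous()
        F = encoder(batch)
        R = torch.softmax(pre_R[start:start + bsize], dim=1)
        distances = torch.sum((F.unsqueeze(1) - U) ** 2, dim=2)
        loss = torch.sum(R * distances) / F.shape[0]
        loss.backward()   # frees this batch's graph before the next one is built
        opt.step()
```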

And if I keep volatile=True in
batch = Variable(model.get_batch(sentences), volatile=True)

I get the error below, which was the main reason I tried setting it to False or removing it altogether, among the other options I listed in my previous reply. I have also tried searching for similar questions on this forum and on Stack Overflow.


RuntimeError                              Traceback (most recent call last)
<ipython-input-26-cf5dced5a7fc> in <module>()
     58     loss = total_loss + regularizer
     59 
---> 60     loss.backward()
     61 #     loss.backward(retain_graph=True)

/anaconda/envs/py35/lib/python3.5/site-packages/torch/autograd/variable.py in backward(self, gradient, retain_graph, create_graph, retain_variables)
    154                 Variable.
    155         """
--> 156         torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
    157 
    158     def register_hook(self, hook):

/anaconda/envs/py35/lib/python3.5/site-packages/torch/autograd/__init__.py in backward(variables, grad_variables, retain_graph, create_graph, retain_variables)
     96 
     97     Variable._execution_engine.run_backward(
---> 98         variables, grad_variables, retain_graph)
     99 
    100 

RuntimeError: element 0 of variables tuple is volatile
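This error is what a volatile input is expected to produce: volatile=True tells autograd not to record a graph, so there is nothing for backward() to traverse. In PyTorch 0.4 and later the volatile flag no longer has any effect and `torch.no_grad()` plays its role. The sketch below reproduces the behaviour with a hypothetical `nn.Linear` model in current PyTorch; the exact error text differs from the 0.2-era message above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(4, 2)   # hypothetical stand-in model
x = torch.randn(3, 4)

# Modern counterpart of Variable(..., volatile=True): no graph is recorded.
with torch.no_grad():
    out_no_graph = encoder(x)

error_message = None
try:
    out_no_graph.sum().backward()
except RuntimeError as err:
    error_message = str(err)   # the output has no grad_fn, so backward fails

# Outside no_grad (like volatile=False) the graph is recorded and
# backward() reaches the parameters.
out_graph = encoder(x)
out_graph.sum().backward()
print(error_message)
```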