Upgrade issue with Trace and Autograd (from torch 1.0.1 to torch 1.3.1)

Hi, I’ve been trying to migrate some code from torch 1.0.1 to torch 1.3.1 but I’m struggling with an error that appears only with the latest version.

Here is a minimal reproducible example that illustrates my problem (just a random MLP with embedding layers for the first 3 categorical columns):

import torch
class NN_emb_long(torch.nn.Module):
    def __init__(self,):
        super(NN_emb_long, self).__init__()
        emb_layers = []
        for i in range(3):
            emb_layers.append(torch.nn.Embedding(5, 2))
        self.emb_layers = torch.nn.ModuleList(emb_layers)        
        self.lin1 = torch.nn.Linear(8, 16)
        self.lin_out = torch.nn.Linear(16, 2)
    
    def forward(self, x):
        embs_list = []
        for i in range(3):
            embs = self.emb_layers[i](x[:,i].long())
            embs_list.append(embs)
        post_embed = torch.cat([x[:, torch.Tensor([3,4]).long()]]+embs_list, dim=1)
        res = self.lin1(post_embed)
        res = torch.nn.ReLU()(res)
        res = self.lin_out(res)
        return res, post_embed

NN = NN_emb_long()
input_example = torch.ones((10, 5)).requires_grad_(True)

probas, post_embeddings = NN(input_example)
grad_outputs = torch.ones(input_example.shape[0],2)
G = torch.autograd.grad(outputs=probas,
                        inputs=post_embeddings,
                        grad_outputs=grad_outputs,
                        only_inputs=True,
                        retain_graph=True
                       )[0]
print(G.shape)
# Up to this point, everything works on both versions.
# Tracing the model triggers an error only with torch 1.3.1.

basic_trace = torch.jit.trace(NN, input_example, check_trace=True)
probas,  post_embeddings = basic_trace(input_example)
grad_outputs = torch.ones(10,2)
G = torch.autograd.grad(outputs=probas,
                        inputs=post_embeddings,
                        grad_outputs=grad_outputs,
                        only_inputs=True,
                        retain_graph=True,
                       )[0]
print(G.shape)

This code runs fine with torch 1.0.1 but fails with the following error on torch 1.3.1:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-d55fb440d76f> in <module>
      6                         grad_outputs=grad_outputs,
      7                         only_inputs=True,
----> 8                         retain_graph=True,
      9                        )[0]
     10 print(G.shape)

.cache/poetry/engine-py3.6/lib/python3.6/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
    155     return Variable._execution_engine.run_backward(
    156         outputs, grad_outputs, retain_graph, create_graph,
--> 157         inputs, allow_unused)
    158 
    159 

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Could you please help me understand what happens here?

Thank you!

Looks like my issue went under the radar! Anyone? Any ideas?

Hey, sorry about the delay! I'm not sure what the actual bug is, but this code seems to work fine again for me on master. Could you try it with the nightly version of PyTorch, or with 1.4 (which should be released very soon!)?

Thank you, it seems that the problem is indeed solved by upgrading directly to 1.4.

Hello @driazati

I work with @Sebastien_Fischman and indeed it works in 1.4.0 now.
But we still have a problem: the models we previously trained on version 1.0.1 and saved as JIT files no longer work when we try to load them.

I searched for a while and found that the problematic code inside the JIT file was

_12 = torch.index(_11, [annotate(Tensor, None), cont_idxs])

With the new version, the code would be something like

_12 = torch.index(_11, annotate(List[Optional[Tensor]],[None, cont_idxs]))

If I manually replace the old code with this new line, it works.
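
To find such lines, the saved file can be inspected as a zip archive. This is only a sketch: it assumes the TorchScript source is stored in `.py` entries of the archive, and `find_legacy_lines` and `LEGACY_MARKER` are hypothetical names, not part of PyTorch.

```python
import zipfile

# Old-style annotation quoted above (assumed to be the only pattern of interest).
LEGACY_MARKER = "annotate(Tensor, None)"


def find_legacy_lines(archive):
    """Return (entry_name, line_number, line) for each source line with the marker.

    `archive` is a path or a binary file-like object pointing at the saved model.
    """
    hits = []
    with zipfile.ZipFile(archive) as zf:
        for name in zf.namelist():
            # Assumption: generated TorchScript source lives in .py entries;
            # everything else (tensors, metadata) is skipped.
            if not name.endswith(".py"):
                continue
            text = zf.read(name).decode("utf-8")
            for lineno, line in enumerate(text.splitlines(), start=1):
                if LEGACY_MARKER in line:
                    hits.append((name, lineno, line.strip()))
    return hits
```

Calling `find_legacy_lines("model.pt")` (the file name is a placeholder) lists every entry and line that still uses the old annotation.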

The problem is that we use PyTorch models in production, and manually modifying the models will be an issue.

Do you have a migration script to update the JIT file, or any solution to make the JIT backward compatible?

Are you able to share the model file that was saved on 1.0.1? The JIT should always be 100% backwards compatible so this sounds like a bug.

Sure, here is the JIT file: https://srv-file9.gofile.io/download/TCKSJe/9e06eca9fd9c2e.pt
I could not attach it to the message.

This JIT file did not work in version 1.3 either, @driazati.

@driazati looks like sharing a model with this kind of link is a bad idea :)

@Hartorn's account has been locked for sharing this ^^ but it's safe to open if you wish to have a look at it!

Hello @driazati,
Sorry to be a bother, but have you had time to look at the model?
Do you think this PyTorch bug will be fixed, or should we find a way to update or modify our JIT file to make it compatible?

Regards

Hey, I wasn't able to download the file (I get a "You are not authorized to download this file" error), but I think I reproduced the same issue. I think this patch should fix the bug. If you are able to build PyTorch from source (see below), you can verify that it fixes the issue:

git clone https://github.com/pytorch/pytorch --depth=1
cd pytorch
git fetch origin driazati/fix_annotate
git reset --hard origin/driazati/fix_annotate
# Now build PyTorch from source
# https://github.com/pytorch/pytorch/#from-source

Thanks @driazati, I think we'll go for a migration script on our side anyway! Thanks for the help, we really appreciate it.
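
In case it helps someone else, here is a minimal sketch of what such a migration script could look like. It rests on several assumptions: the saved model is a zip archive, the TorchScript source sits in `.py` entries, and only the exact `torch.index` shape quoted earlier in this thread needs rewriting. `rewrite_index_calls` and `migrate_archive` are hypothetical names, and entries are written uncompressed as a cautious guess about what the loader expects.

```python
import re
import zipfile

# Matches only the shape quoted in this thread:
#   torch.index(x, [annotate(Tensor, None), idx])
_OLD_INDEX = re.compile(
    r"torch\.index\((\w+), \[annotate\(Tensor, None\), (\w+)\]\)"
)


def rewrite_index_calls(source):
    """Rewrite old-style index calls to the List[Optional[Tensor]] form."""
    return _OLD_INDEX.sub(
        r"torch.index(\1, annotate(List[Optional[Tensor]],[None, \2]))",
        source,
    )


def migrate_archive(src, dst):
    """Copy the archive `src` to `dst`, rewriting .py entries.

    Both arguments may be paths or binary file-like objects.
    Returns the number of entries that were modified.
    """
    changed = 0
    with zipfile.ZipFile(src) as src_zip, \
            zipfile.ZipFile(dst, "w", zipfile.ZIP_STORED) as dst_zip:
        for info in src_zip.infolist():
            data = src_zip.read(info.filename)
            # Assumption: TorchScript source is stored in .py entries;
            # all other entries (tensors, metadata) are copied unchanged.
            if info.filename.endswith(".py"):
                text = data.decode("utf-8")
                new_text = rewrite_index_calls(text)
                if new_text != text:
                    changed += 1
                    data = new_text.encode("utf-8")
            dst_zip.writestr(info.filename, data)
    return changed
```

Keep the original file and check the result with `torch.jit.load` before deploying it; this sketch does not verify that the rewritten archive loads, and other legacy patterns may need their own rules.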
