RuntimeError: tensor does not have a device?

What’s wrong?

        print("loss: (1)",loss)
        print("penalty: ",penalty)
        loss += penalty
        print("loss: (2)",loss)
        loss.backward()

I got

loss: (1) tensor(5.2173, device='cuda:0', grad_fn=<NllLossBackward>)
penalty:  tensor(0.1816, device='cuda:0', grad_fn=<DivBackward0>)
loss: (2) tensor(5.3989, device='cuda:0', grad_fn=<AddBackward0>)

Traceback (most recent call last):
    loss.backward()
  File ".../torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: tensor does not have a device

Any idea why I get this RuntimeError?

I tried making some similar inputs and couldn’t reproduce, so we might need more context, or maybe this is a bug that was fixed in a more recent version of PyTorch:

In [31]: output
Out[31]: tensor(1.2804, device='cuda:0', grad_fn=<NllLossBackward>)
In [32]: penalty
Out[32]: tensor(3.3333, device='cuda:0', grad_fn=<DivBackward0>)
In [33]: output += penalty
In [34]: output.backward()
In [35]: output
Out[35]: tensor(4.6137, device='cuda:0', grad_fn=<AddBackward0>)
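
For completeness, here is a self-contained version of that quick check (a sketch with stand-in tensors, since I don't have your model or data):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in scalars with a grad history, playing the roles of the loss and the penalty.
w = torch.randn(5, device=device, requires_grad=True)
output = torch.nn.functional.softplus(w).sum()
penalty = w.norm(1) / w.shape[0]

output += penalty      # in-place add, as in the snippet above
output.backward()      # backpropagates through both terms without error
print(output, w.grad)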

Thanks for the interest @nairbv
I have extracted these few lines into a separate script to debug and cannot reproduce the RuntimeError there, while my original script always crashes at the first batch.

Can someone tell me in which cases the “RuntimeError: tensor does not have a device” is raised? I have dumped all the tensors involved in the computation of the loss and they are all on the cuda:0 device.
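
The dump itself was nothing fancy, roughly along these lines (a sketch with stand-in names; model, inputs and target replace the objects of my actual script):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 4).to(device)                  # stand-in for my network
inputs = torch.randn(2, 8, device=device)           # stand-in for the image batch
target = torch.randint(0, 4, (2,), device=device)   # stand-in for the labels

# Print the device of every tensor feeding the loss computation.
for name, p in model.named_parameters():
    print(name, p.device)
print("inputs:", inputs.device, "target:", target.device)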
Thanks

Hi, I have some new material.

import torch
import torch.nn as nn
from torch.autograd import grad
import numpy as np

class PzConv2d(nn.Module):
    """ Convolution 2D Layer followed by PReLU activation
    """
    def __init__(self, n_in_channels, n_out_channels, **kwargs):
        super(PzConv2d, self).__init__()
        self.conv = nn.Conv2d(n_in_channels, n_out_channels, bias=True,
                            **kwargs)
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.constant_(self.conv.bias,0.1)
        self.activ = nn.PReLU(num_parameters=n_out_channels, init=0.25)

    def forward(self, x):
        x = self.conv(x)
        return self.activ(x)


class PzPool2d(nn.Module):
    """ Average Pooling Layer
    """
    def __init__(self, kernel_size, stride, padding=0):
        super(PzPool2d, self).__init__()
        self.pool = nn.AvgPool2d(kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 ceil_mode=True,
                                 count_include_pad=False)

    def forward(self, x):
        return self.pool(x)


class PzFullyConnected(nn.Module):
    """ Dense or Fully Connected Layer followed by ReLU
    """
    def __init__(self, n_inputs, n_outputs, withrelu=True, **kwargs):
        super(PzFullyConnected, self).__init__()
        self.withrelu = withrelu
        self.linear = nn.Linear(n_inputs, n_outputs, bias=True)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0.1)
        self.activ = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        if self.withrelu:
            x = self.activ(x)
        return x


class PzInception(nn.Module):
    """ Inspection module

        The input (x) is dispatched between

        o a cascade of conv layers s1_0 1x1 , s2_0 3x3
        o a cascade of conv layer s1_2 1x1, followed by pooling layer pool0 2x2
        o a cascade of conv layer s2_2 1x1
        o optionally a cascade of conv layers s1_1 1x1, s2_1 5x5

        then the 3 (or 4) intermediate outputs are concatenated
    """
    def __init__(self, n_in_channels, n_out_channels_1, n_out_channels_2,
                 without_kernel_5=False, debug=False):
        super(PzInception, self).__init__()
        self.debug = debug
        self.s1_0 = PzConv2d(n_in_channels, n_out_channels_1,
                             kernel_size=1, padding=0)
        self.s2_0 = PzConv2d(n_out_channels_1, n_out_channels_2,
                             kernel_size=3, padding=1)

        self.s1_2 = PzConv2d(n_in_channels, n_out_channels_1, kernel_size=1)
        self.pad0 = nn.ZeroPad2d([0, 1, 0, 1])
        self.pool0 = PzPool2d(kernel_size=2, stride=1, padding=0)

        self.without_kernel_5 = without_kernel_5
        if not (without_kernel_5):
            self.s1_1 = PzConv2d(n_in_channels, n_out_channels_1,
                                 kernel_size=1, padding=0)
            self.s2_1 = PzConv2d(n_out_channels_1, n_out_channels_2,
                                 kernel_size=5, padding=2)

        self.s2_2 = PzConv2d(n_in_channels, n_out_channels_2, kernel_size=1,
                             padding=0)

    def forward(self, x):
        # x: image tensor N_batch, Channels, Height, Width
        x_s1_0 = self.s1_0(x)
        x_s2_0 = self.s2_0(x_s1_0)

        x_s1_2 = self.s1_2(x)

        x_pool0 = self.pool0(self.pad0(x_s1_2))

        if not (self.without_kernel_5):
            x_s1_1 = self.s1_1(x)
            x_s2_1 = self.s2_1(x_s1_1)

        x_s2_2 = self.s2_2(x)

        if self.debug: print("Inception x_s1_0  :", x_s1_0.size())
        if self.debug: print("Inception x_s2_0  :", x_s2_0.size())
        if self.debug: print("Inception x_s1_2  :", x_s1_2.size())
        if self.debug: print("Inception x_pool0 :", x_pool0.size())

        if not (self.without_kernel_5) and self.debug:
            print("Inception x_s1_1  :", x_s1_1.size())
            print("Inception x_s2_1  :", x_s2_1.size())

        if self.debug: print("Inception x_s2_2  :", x_s2_2.size())

        # to be checked: dim=1 => NCHW (in TensorFlow: axis=3, NHWC)
        if not (self.without_kernel_5):
            output = torch.cat((x_s2_2, x_s2_1, x_s2_0, x_pool0), dim=1)
        else:
            output = torch.cat((x_s2_2, x_s2_0, x_pool0), dim=1)

        if self.debug: print("Inception output :", output.shape)
        return output


class NetWithInception(nn.Module):
    """ The Networks
        inputs: the image (x), the reddening vector


        The image 64x64x5 is fed forward through
        o a conv layer 5x5
        o a pooling layer 2x2
        o 5 inception modules, the last one without the 5x5 branch

        Then, we concatenate the result with the reddening vector to perform
        o 3 fully connected layers

        The output dimension is given by n_bins
        There is no softmax activation here, to allow the use of the CrossEntropy loss

    """
    def __init__(self, n_input_channels, debug=False):
        super(NetWithInception, self).__init__()
        
        # the number of bins to represent the output photo-z
        self.n_bins = 180

        self.debug = debug
        self.conv0 = PzConv2d(n_in_channels=n_input_channels,
                              n_out_channels=64,
                              kernel_size=5, padding=2)
        self.pool0 = PzPool2d(kernel_size=2, stride=2, padding=0)
        # for the Softmax the input tensor shape is [1,n] so apply on axis=1
        # t1 = torch.rand([1,10])
        # t2 = nn.Softmax(dim=1)(t1)
        # torch.sum(t2) = 1
        self.i0 = PzInception(n_in_channels=64,
                              n_out_channels_1=48,
                              n_out_channels_2=64)

        self.i1 = PzInception(n_in_channels=240,
                              n_out_channels_1=64,
                              n_out_channels_2=92)

        self.i2 = PzInception(n_in_channels=340,
                              n_out_channels_1=92,
                              n_out_channels_2=128)

        self.i3 = PzInception(n_in_channels=476,
                              n_out_channels_1=92,
                              n_out_channels_2=128)

        self.i4 = PzInception(n_in_channels=476,
                              n_out_channels_1=92,
                              n_out_channels_2=128,
                              without_kernel_5=True)

        self.fc0 = PzFullyConnected(n_inputs=22273, n_outputs=1096)
        self.fc1 = PzFullyConnected(n_inputs=1096, n_outputs=1096)
        self.fc2 = PzFullyConnected(n_inputs=1096, n_outputs=self.n_bins)


    def num_flat_features(self, x):
        """

        Parameters
        ----------
        x: the input

        Returns
        -------
        the total number of features, i.e. the number of elements of the tensor excluding the batch dimension

        """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def forward(self, x, reddening):
        # x: image tensor N_batch, Channels, Height, Width
        #    size N, Channels=5 filters, H, W = 64 pixels
        # save original image
        x_in = x

        if self.debug: print("input shape: ", x.size())
        x = self.conv0(x)
        if self.debug: print("conv0 shape: ", x.size())
        x = self.pool0(x)
        if self.debug: print("conv0p shape: ", x.size())
        if self.debug: print('>>>>>>> i0:START <<<<<<<')
        x = self.i0(x)
        if self.debug: print("i0 shape: ", x.size())

        if self.debug: print('>>>>>>> i1:START <<<<<<<')
        x = self.i1(x)

        x = self.pool0(x)
        if self.debug: print("i1p shape: ", x.size())

        if self.debug: print('>>>>>>> i2:START <<<<<<<')
        x = self.i2(x)
        if self.debug: print("i2 shape: ", x.size())

        if self.debug: print('>>>>>>> i3:START <<<<<<<')
        x = self.i3(x)
        x = self.pool0(x)
        if self.debug: print("i3p shape: ", x.size())

        if self.debug: print('>>>>>>> i4:START <<<<<<<')
        x = self.i4(x)
        if self.debug: print("i4 shape: ", x.size())

        if self.debug: print('>>>>>>> FC part :START <<<<<<<')
        flat = x.view(-1, self.num_flat_features(x))
        if self.debug: print("flat shape: ", flat.size())
        concat = torch.cat((flat, reddening), dim=1)
        if self.debug: print('concat shape: ', concat.size())

        fcn_in_features = concat.size(-1)
        if self.debug: print('fcn_in_features: ', fcn_in_features)

        x = self.fc0(concat)
        if self.debug: print('fc0 shape: ', x.size())
        x = self.fc1(x)
        if self.debug: print('fc1 shape: ', x.size())
        x = self.fc2(x)
        if self.debug: print('fc2 shape: ', x.size())

        output = x
        if self.debug: print('output shape: ', output.size())

        #params = {"output": output, "x": x_in, "reddening": reddening}
        # return params

        return output

########

img_channels = 5
img_H = 64
img_W = 64
n_batchs = 1

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("device: ",device)

loss_fn = torch.nn.CrossEntropyLoss()
loss_fn_sum = torch.nn.CrossEntropyLoss(reduction='sum')
#m = torch.nn.Linear(20, 30)

m = NetWithInception(img_channels,debug=True)
m.to(device)

ims = torch.randn(n_batchs, img_channels,img_H ,img_W,dtype=torch.float)
reds = torch.zeros([n_batchs,1],dtype=torch.float)
target = torch.empty(n_batchs, dtype=torch.long).random_(180)
ims, target, reds = ims.to(device), target.to(device), reds.to(device)

pred =  m(ims, reds)
loss = loss_fn(pred,target)

n = ims.shape[0]
imsv = ims.clone().requires_grad_()
preds = m(imsv, reds)
xeloss = loss_fn(preds, target)
g, = torch.autograd.grad(xeloss, imsv, create_graph=True)
penalty = g.norm(1) / n


print("loss   : ",loss)
print("penalty: ",penalty)

loss += penalty
tmp = loss.item()
print("loss   : ",loss)
#compute grad
loss.backward()
print("end")

Can someone run this code and tell me if they get the following message:

device:  cuda
input shape:  torch.Size([1, 5, 64, 64])
conv0 shape:  torch.Size([1, 64, 64, 64])
conv0p shape:  torch.Size([1, 64, 32, 32])
>>>>>>> i0:START <<<<<<<
i0 shape:  torch.Size([1, 240, 32, 32])
>>>>>>> i1:START <<<<<<<
i1p shape:  torch.Size([1, 340, 16, 16])
>>>>>>> i2:START <<<<<<<
i2 shape:  torch.Size([1, 476, 16, 16])
>>>>>>> i3:START <<<<<<<
i3p shape:  torch.Size([1, 476, 8, 8])
>>>>>>> i4:START <<<<<<<
i4 shape:  torch.Size([1, 348, 8, 8])
>>>>>>> FC part :START <<<<<<<
flat shape:  torch.Size([1, 22272])
concat shape:  torch.Size([1, 22273])
fcn_in_features:  22273
fc0 shape:  torch.Size([1, 1096])
fc1 shape:  torch.Size([1, 1096])
fc2 shape:  torch.Size([1, 180])
output shape:  torch.Size([1, 180])
input shape:  torch.Size([1, 5, 64, 64])
conv0 shape:  torch.Size([1, 64, 64, 64])
conv0p shape:  torch.Size([1, 64, 32, 32])
>>>>>>> i0:START <<<<<<<
i0 shape:  torch.Size([1, 240, 32, 32])
>>>>>>> i1:START <<<<<<<
i1p shape:  torch.Size([1, 340, 16, 16])
>>>>>>> i2:START <<<<<<<
i2 shape:  torch.Size([1, 476, 16, 16])
>>>>>>> i3:START <<<<<<<
i3p shape:  torch.Size([1, 476, 8, 8])
>>>>>>> i4:START <<<<<<<
i4 shape:  torch.Size([1, 348, 8, 8])
>>>>>>> FC part :START <<<<<<<
flat shape:  torch.Size([1, 22272])
concat shape:  torch.Size([1, 22273])
fcn_in_features:  22273
fc0 shape:  torch.Size([1, 1096])
fc1 shape:  torch.Size([1, 1096])
fc2 shape:  torch.Size([1, 180])
output shape:  torch.Size([1, 180])
loss   :  tensor(5.4763, device='cuda:0', grad_fn=<NllLossBackward>)
penalty:  tensor(0.1425, device='cuda:0', grad_fn=<DivBackward0>)
loss   :  tensor(5.6188, device='cuda:0', grad_fn=<AddBackward0>)
Traceback (most recent call last):
  File "bugs.py", line 301, in <module>
    loss.backward()
  File "...anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: sizes() called on undefined Tensor

It is a different error… but it may be linked to the original one.
Same problem on “cpu” and on “cuda”.

Hey! What about

RuntimeError: sizes() called on undefined Tensor

This is a very strange message, as “sizes()” with an “s” is not a valid Python function!

sizes() would most likely correspond to the C++ function name.

The error is raised when create_graph=True, with:

RuntimeError: Expected a Tensor of type Variable but found an undefined Tensor for argument #0 'self' (checked_cast_variable at /opt/conda/conda-bld/pytorch_1579335088481/work/torch/csrc/autograd/VariableTypeManual.cpp:38)

@albanD Do you have an idea what might have gone wrong?

Hi @ptrblck Thanks for your interest in this new problem.
So, is there no way for me to investigate the C++ source?
The only use of create_graph=True is on this line:

g, = torch.autograd.grad(xeloss, imsv, create_graph=True)

This is part of code originating from https://github.com/albietz/kernel_reg/blob/master/main.py
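
For readers following along, the pattern in question, condensed from the full script I posted above, is a gradient penalty that requires a second backward pass:

# Condensed from the script above: m, loss_fn, ims, reds and target
# are the model, criterion and inputs defined there.
imsv = ims.clone().requires_grad_()
preds = m(imsv, reds)
xeloss = loss_fn(preds, target)

# create_graph=True keeps the graph of this gradient computation,
# so the penalty itself can be backpropagated through later.
g, = torch.autograd.grad(xeloss, imsv, create_graph=True)
penalty = g.norm(1) / ims.shape[0]

loss = loss_fn(m(ims, reds), target) + penalty
loss.backward()   # the failing call: effectively a double backward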

Hello, I have just:

  1. made a fresh install of conda
  2. run conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

Then I get a new error :slight_smile:

Traceback (most recent call last):
  File "bugs.py", line 333, in <module>
    loss.backward()
  File "...python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Expected a Tensor of type Variable but found an undefined Tensor for argument #0 'self'

Hi,

Can you run your code with torch.autograd.set_detect_anomaly(True), please, to find out which forward function is responsible for this?

@albanD Thanks for your interest

I have added your “magic” line right after the imports:

import torch
import torch.nn as nn
from torch.autograd import grad
import numpy as np

torch.autograd.set_detect_anomaly(True) 

Here is the Traceback

Warning: Traceback of forward call that caused the error:
  File "...anaconda3/lib/python3.7/traceback.py", line 197, in format_stack
    return format_list(extract_stack(f, limit=limit))
 (print_stack at /opt/conda/conda-bld/pytorch_1579022060824/work/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "bugs.py", line 335, in <module>
    loss.backward()
  File "...anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "...anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Expected a Tensor of type Variable but found an undefined Tensor for argument #0 'self'

Some new info: I have run the code on Google Colab.
It uses Python 3.6 and torch 1.3.1, and it crashes both on CPU and on a GPU (K80), with the same message.

/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
  File "/usr/lib/python3.6/traceback.py", line 197, in format_stack
    return format_list(extract_stack(f, limit=limit))

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-1-e5464bc91872> in <module>()
    333 print("loss   : ",loss)
    334 #compute grad
--> 335 loss.backward()
    336 
    337 print("end")

1 frames

/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
---> 99         allow_unreachable=True)  # allow_unreachable flag
    100 
    101 

RuntimeError: sizes() called on undefined Tensor

Maybe @albanD you have an idea?

Hi,

I ran your code above from RuntimeError: tensor does not have a device? and it runs without any issue (on two machines with current master).

Could you try upgrading to 1.4, which just came out, to see if you still see this happening?

FYI, the issue is hard to track from the error message because it happens during the second backward, due to a function that ran during the first backward (hence the weird stack trace from anomaly mode).
The error message means that some C++ code that expects a Tensor containing values was given an undefined Tensor (which, for gradients, represents a Tensor full of 0s).
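
As a rough Python-level illustration (a toy example, not code from this thread): when an input does not contribute to an output, autograd has no gradient to hand back, which allow_unused=True exposes as None on the Python side:

import torch

a = torch.rand(2, requires_grad=True)
b = torch.rand(2, requires_grad=True)

out = (2 * a).sum()   # b never participates in the graph

# The gradient w.r.t. b is "undefined": returned as None instead of zeros.
ga, gb = torch.autograd.grad(out, (a, b), allow_unused=True)
print(ga)   # tensor([2., 2.])
print(gb)   # None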

Hi @albanD
In fact I had already upgraded, with a completely fresh install of both conda and torch 1.4.0 (see my message just before your suggestion to use torch.autograd.set_detect_anomaly(True)).

The RuntimeError message changes from PyTorch 1.3.1 to 1.4.0.

I will give you my config at the computer center (CC) I use, but keep in mind that I have also run into trouble in the Google Colab environment.

Here is the config at Computer Center:

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: CentOS Linux release 7.7.1908 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-39)
CMake version: version 2.8.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.105
GPU models and configuration: GPU 0: Tesla V100-PCIE-32GB
Nvidia driver version: 418.87.01
cuDNN version: /opt/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.5.0

Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] numpydoc==0.9.1
[pip] torch==1.4.0
[pip] torchvision==0.5.0
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.4                      243  
[conda] mkl-service               2.3.0            py37he904b0f_0  
[conda] mkl_fft                   1.0.14           py37ha843d7b_0  
[conda] mkl_random                1.1.0            py37hd6b4f25_0  
[conda] pytorch                   1.4.0           py3.7_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchvision               0.5.0                py37_cu101    pytorch

Hi,

Here is a much smaller repro:

import torch
import torch.nn as nn
from torch.autograd import grad

print(torch.__version__)

imsv = torch.rand(1, 3, 6, 6, requires_grad=True)

x_s1_2 = nn.PReLU(num_parameters=3)(imsv)

tmp = nn.ZeroPad2d([0, 1, 0, 1])(x_s1_2)
preds = nn.AvgPool2d(kernel_size=2,
                        stride=1,
                        padding=0,
                        ceil_mode=True,
                        count_include_pad=False)(tmp)

g, = grad(preds.sum(), imsv, create_graph=True)

grad(g.sum(), imsv)

I still cannot reproduce with latest master, though. Could you try installing the nightly build?
If it still fails for you on the nightly build, please open an issue on GitHub with the short repro and all your install / computer details!

@albanD Thanks for your interest. Can you tell me how I can get the “nightly build” via a conda installation?

You can select “nightly” on the getting started page: https://pytorch.org/get-started/locally/

@albanD

Good news! With the nightly build 1.5.0.dev20200129 I have no crash, neither on CPU nor on GPU, and both for your code and my longer one :slight_smile:

So the problem has been solved! Maybe you can investigate when the bug got fixed…

Thanks a lot.

PS: note that the first time I performed an import torch, this happened:

    import torch
  File "/sps/lsst/users/campagne/anaconda3/lib/python3.7/site-packages/torch/__init__.py", line 125, in <module>
    from torch._C import *
ImportError: .../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so: undefined symbol: _ZN5torch4cuda4nccl6detail16throw_nccl_errorE12ncclResult_t

What most likely happened here is that we changed some of the gradient formulas to make them handle undefined Tensors properly.
I reviewed the ones used in the sample in master and they all looked good.


Ha! Darn, the problem is still there: loss.backward() crashes, but I get a longer Traceback for my original script :slight_smile:
I use PyTorch 1.5.0.dev20200129. The crash occurs neither in the shorter script I gave at the beginning of this thread nor in the smaller exercise given by @albanD. Here is the Traceback:

Traceback (most recent call last):
  File "pz_train_kernelreg.py", line 562, in <module>
    main()
  File "pz_train_kernelreg.py", line 515, in main
    epoch, perturb=perturb)
  File "pz_train_kernelreg.py", line 133, in train
    loss.backward()
  File "...anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/.../anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: tensor does not have a device (device at /opt/conda/conda-bld/pytorch_1580285301452/work/c10/core/TensorImpl.h:463)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x4e (0x2b0eb9c5c40e in /.../anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: at::Tensor::options() const + 0x1ff (0x2b0e81d805df in /.../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #2: torch::autograd::generated::SliceBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0xea (0x2b0e8bb7e11a in .../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x2e672fd (0x2b0e8c2e62fd in .../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x154b (0x2b0e8c2e277b in /.../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x588 (0x2b0e8c2e3c88 in .../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #6: torch::autograd::Engine::thread_init(int) + 0x39 (0x2b0e8c2dac99 in /.../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #7: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x2b0e821d38a8 in .../anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0xc819d (0x2b0e7f93e19d in .../anaconda3/lib/python3.7/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #9: <unknown function> + 0x7e65 (0x2b0e60e87e65 in /lib64/libpthread.so.0)
frame #10: clone + 0x6d (0x2b0e6119a88d in /lib64/libc.so.6)