RuntimeError: tensor does not have a device?

@albanD
Here is a stand-alone script, bugs_v3.py. I have reduced the model as much as possible while still exhibiting the crash. If you run it as
o `python bugs_v3.py --bs 1`: it does not crash :slight_smile:
o `python bugs_v3.py --bs 2`: it crashes :frowning:
o `python bugs_v3.py --bs 2 --noincept`: it does not crash :slight_smile:

The `--noincept` option bypasses the Inception cell.

I am not expert enough to reduce the code further while still exhibiting the problem, but as you can see, the batch size is critical when the Inception cell is present. Maybe you or someone else will find the problem.

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import torch.nn.functional as F
import random
import types
import numpy as np

import os


class PzConv2d(nn.Module):
    """2-D convolution followed by a channel-wise PReLU activation.

    Extra keyword arguments (``kernel_size``, ``padding``, ...) are forwarded
    to ``nn.Conv2d``. The convolution weights are re-initialised with
    Xavier-uniform (instead of PyTorch's default kaiming_uniform_), the bias
    is set to the constant 0.1, and every PReLU slope starts at 0.25.
    """

    def __init__(self, n_in_channels, n_out_channels, **kwargs):
        super(PzConv2d, self).__init__()
        conv = nn.Conv2d(n_in_channels, n_out_channels, bias=True, **kwargs)
        nn.init.xavier_uniform_(conv.weight)
        nn.init.constant_(conv.bias, 0.1)
        self.conv = conv
        # BatchNorm2d between conv and activation was tried and disabled.
        self.activ = nn.PReLU(num_parameters=n_out_channels, init=0.25)

    def forward(self, x):
        return self.activ(self.conv(x))


class PzPool2d(nn.Module):
    """Thin wrapper around ``nn.AvgPool2d``.

    ``ceil_mode=True`` keeps partial windows at the right/bottom border, and
    ``count_include_pad=False`` excludes zero padding from the averages.
    """

    def __init__(self, kernel_size, stride, padding=0):
        super(PzPool2d, self).__init__()
        pool_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, ceil_mode=True,
                            count_include_pad=False)
        self.pool = nn.AvgPool2d(**pool_options)

    def forward(self, x):
        pooled = self.pool(x)
        return pooled


class PzFullyConnected(nn.Module):
    """Dense (fully connected) layer, optionally followed by a ReLU.

    Weights are re-initialised with Xavier-uniform and the bias is set to the
    constant 0.1. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, n_inputs, n_outputs, withrelu=True, **kwargs):
        super(PzFullyConnected, self).__init__()
        self.withrelu = withrelu
        linear = nn.Linear(n_inputs, n_outputs, bias=True)
        nn.init.xavier_uniform_(linear.weight)
        nn.init.constant_(linear.bias, 0.1)
        self.linear = linear
        self.activ = nn.ReLU()

    def forward(self, x):
        y = self.linear(x)
        return self.activ(y) if self.withrelu else y


class PzInception(nn.Module):
    """Inception module.

    The input ``x`` is dispatched to parallel branches:

        o s1_0 (1x1 conv) -> s2_0 (3x3 conv, padding 1)
        o s1_2 (1x1 conv) -> pad0 (one zero row/column on the right/bottom)
          -> pool0 (2x2 average pooling, stride 1)
        o s2_2 (1x1 conv)
        o unless ``without_kernel_5``: s1_1 (1x1 conv) -> s2_1 (5x5 conv, padding 2)

    The 3 (or 4) branch outputs are concatenated along dim 1 in the order
    s2_2, [s2_1,] s2_0, pool0. With ``debug`` set, intermediate sizes are
    printed.
    """

    def __init__(self, n_in_channels, n_out_channels_1, n_out_channels_2,
                 without_kernel_5=False, debug=False):
        super(PzInception, self).__init__()
        self.debug = debug
        # 1x1 -> 3x3 branch
        self.s1_0 = PzConv2d(n_in_channels, n_out_channels_1,
                             kernel_size=1, padding=0)
        self.s2_0 = PzConv2d(n_out_channels_1, n_out_channels_2,
                             kernel_size=3, padding=1)
        # 1x1 -> pad -> pool branch
        self.s1_2 = PzConv2d(n_in_channels, n_out_channels_1, kernel_size=1)
        self.pad0 = nn.ZeroPad2d([0, 1, 0, 1])
        self.pool0 = PzPool2d(kernel_size=2, stride=1, padding=0)
        # optional 1x1 -> 5x5 branch
        self.without_kernel_5 = without_kernel_5
        if not without_kernel_5:
            self.s1_1 = PzConv2d(n_in_channels, n_out_channels_1,
                                 kernel_size=1, padding=0)
            self.s2_1 = PzConv2d(n_out_channels_1, n_out_channels_2,
                                 kernel_size=5, padding=2)
        # plain 1x1 branch
        self.s2_2 = PzConv2d(n_in_channels, n_out_channels_2, kernel_size=1,
                             padding=0)

    def forward(self, x):
        # x: image tensor N_batch, Channels, Height, Width
        use_k5 = not self.without_kernel_5

        s1_0 = self.s1_0(x)
        s2_0 = self.s2_0(s1_0)
        s1_2 = self.s1_2(x)
        pooled = self.pool0(self.pad0(s1_2))
        if use_k5:
            s1_1 = self.s1_1(x)
            s2_1 = self.s2_1(s1_1)
        s2_2 = self.s2_2(x)

        if self.debug:
            print("Inception x_s1_0  :", s1_0.size())
            print("Inception x_s2_0  :", s2_0.size())
            print("Inception x_s1_2  :", s1_2.size())
            print("Inception x_pool0 :", pooled.size())
            if use_k5:
                print("Inception x_s1_1  :", s1_1.size())
                print("Inception x_s2_1  :", s2_1.size())
            print("Inception x_s2_2  :", s2_2.size())

        # dim=1 is the channel axis in NCHW (TensorFlow NHWC would use axis=3)
        branches = [s2_2] + ([s2_1] if use_k5 else []) + [s2_0, pooled]
        output = torch.cat(branches, dim=1)

        if self.debug:
            print("Inception output :", output.shape)
        return output


class NetWithInception(nn.Module):
    """Photo-z network: conv/pool stem, one Inception module, one dense layer.

    Inputs: the image batch ``x`` and the reddening tensor.

    The image (``n_input_channels`` x ``img_size`` x ``img_size``; 5x64x64 in
    the training script) is fed through
        o a 5x5 conv layer with 64 output channels (padding 2)
        o a 2x2 average pooling layer (stride 2, ceil_mode)
        o the Inception module ``i0`` (skipped when ``no_inception`` is True;
          the module is still created)

    The result is flattened, concatenated with the reddening tensor and fed
    to one fully connected layer ``fc0`` with ``n_bins`` outputs. ``fc0`` uses
    PzFullyConnected's default ReLU; there is no softmax, so the output can be
    passed to a Cross Entropy loss.

    Parameters
    ----------
    n_input_channels: number of image channels
    no_inception: bypass ``i0``
    debug: print intermediate shapes in ``forward``
    img_size: height/width of the (square) input images; defaults to 64
    n_reddening: number of reddening features per sample; defaults to 1
    """

    def __init__(self, n_input_channels, no_inception=False, debug=False,
                 img_size=64, n_reddening=1):
        super(NetWithInception, self).__init__()

        # the number of bins to represent the output photo-z
        self.n_bins = 180

        self.debug = debug
        n_conv0 = 64
        self.conv0 = PzConv2d(n_in_channels=n_input_channels,
                              n_out_channels=n_conv0,
                              kernel_size=5, padding=2)
        self.pool0 = PzPool2d(kernel_size=2, stride=2, padding=0)

        n_out_1, n_out_2 = 48, 64
        self.i0 = PzInception(n_in_channels=n_conv0,
                              n_out_channels_1=n_out_1,
                              n_out_channels_2=n_out_2)

        self.no_inception = no_inception

        # fc0 input size, derived instead of hard-coded (65537 / 245761 for
        # 64x64 images). conv0 keeps the spatial size, and pool0 halves it
        # (rounding up, ceil_mode=True). All Inception branches keep the
        # spatial size and concatenate 3 * n_out_2 + n_out_1 channels
        # (s2_2, s2_1, s2_0: n_out_2 each; pool0 branch: n_out_1).
        side = (img_size + 1) // 2
        n_channels = n_conv0 if no_inception else 3 * n_out_2 + n_out_1
        self.fc0 = PzFullyConnected(
            n_inputs=n_channels * side * side + n_reddening,
            n_outputs=self.n_bins)

    def num_flat_features(self, x):
        """

        Parameters
        ----------
        x: the input

        Returns
        -------
        the total number of features = number of elements of the tensor except the batch dimension

        """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def forward(self, x, reddening):
        # x: image tensor N_batch, Channels, Height, Width
        debug = self.debug
        if debug:
            print("input shape: ", x.size())
        x = self.conv0(x)
        if debug:
            print("conv0 shape: ", x.size())
        x = self.pool0(x)
        if debug:
            print("conv0p shape: ", x.size())
            print('>>>>>>> i0:START <<<<<<<')

        if not self.no_inception:
            x = self.i0(x)
        if debug:
            print("i0 shape: ", x.size())
            print('>>>>>>> FC part :START <<<<<<<')

        # keep the batch on dim 0 and flatten everything else
        flat = torch.flatten(x, start_dim=1)
        if debug:
            print("flat shape: ", flat.size())
        concat = torch.cat((flat, reddening), dim=1)
        if debug:
            print('concat shape: ', concat.size())
            print('fcn_in_features: ', concat.size(-1))

        output = self.fc0(concat)
        if debug:
            print('fc0 shape: ', output.size())
            print('output shape: ', output.size())
        return output


class LossAvgGradL1(object):
    """grad-l1: penalty equal to the L1 norm of the loss gradient w.r.t. the inputs.

    ``loss`` returns ``lmbda * ||d loss_fn(model(ims), labels) / d ims||_1 / n``
    with ``n = ims.shape[0]``. The gradient is built with
    ``create_graph=True`` so the returned penalty can itself be
    back-propagated into the model parameters (double backward).

    Parameters
    ----------
    model: callable, invoked as ``model(ims)`` or ``model(ims, extra)``
    loss_fn: loss module whose ``reduction`` is ``'sum'`` (the averaging over
        the batch is done here, by dividing by ``n``)
    lmbda: weight of the penalty

    Raises
    ------
    ValueError: if ``loss_fn.reduction`` is not ``'sum'``
    """

    def __init__(self, model, loss_fn, lmbda):
        # Explicit check: an ``assert`` would be stripped under ``python -O``.
        if loss_fn.reduction != 'sum':
            raise ValueError('need a sum reduction for the loss')
        self.model = model
        self.loss_fn = loss_fn
        self.lmbda = lmbda

    def loss(self, ims, labels, extra=None):
        """Return the gradient-L1 penalty for one batch.

        ``extra`` (e.g. the reddening tensor), when given, is passed as the
        model's second positional argument.
        """
        n = ims.shape[0]
        # differentiate w.r.t. a copy of the inputs that requires grad
        imsv = ims.clone().requires_grad_()
        if extra is None:
            preds = self.model(imsv)
        else:
            preds = self.model(imsv, extra)
        xeloss = self.loss_fn(preds, labels)
        g, = grad(xeloss, imsv, create_graph=True)
        return self.lmbda * g.norm(1) / n

############

def train(args, model, device, optimizer, loss_fn,
          epoch, perturb=None, **perturb_args):
    """Run one training epoch on random synthetic batches.

    Each of the ``args.nloop`` iterations draws a random image batch of shape
    (args.batch_size, 5, 64, 64), a zero reddening tensor of shape
    (args.batch_size, 1) and random integer labels in [0, model.n_bins),
    moves them to ``device``, and performs one optimizer step.

    Parameters
    ----------
    args: namespace with ``nloop`` and ``batch_size`` attributes
    model: network called as ``model(images, reddening)``; must expose ``n_bins``
    device: device the batches are moved to
    optimizer: optimizer over the model parameters
    loss_fn: loss called as ``loss_fn(output, labels)``
    epoch: epoch index, only used in the progress print
    perturb: optional penalty object with ``loss(ims, labels, extra=...)``;
        its value is added to the loss. Skipped when None.
    perturb_args: accepted for backward compatibility, unused

    Returns
    -------
    the mean of the per-batch losses over the epoch (0.0 when ``args.nloop``
    is 0). The previous version returned ``batch_size`` times this value.
    """
    # switch network layers to Training mode
    model.train()

    n_bins = model.n_bins  # number of output neurons
    n_loops = args.nloop
    n_batchs = args.batch_size
    img_channels, img_H, img_W = 5, 64, 64

    train_loss = 0.0
    for _ in range(n_loops):
        new_img_batch = torch.randn(n_batchs, img_channels, img_H, img_W,
                                    dtype=torch.float)
        ebv_batch = torch.zeros([n_batchs, 1], dtype=torch.float)
        new_z_batch = torch.empty(n_batchs, dtype=torch.long).random_(n_bins)

        # send the inputs and target to the device
        new_img_batch = new_img_batch.to(device)
        ebv_batch = ebv_batch.to(device)
        new_z_batch = new_z_batch.to(device)

        optimizer.zero_grad()
        output = model(new_img_batch, ebv_batch)
        loss = loss_fn(output, new_z_batch)

        if perturb is not None:
            # warning: argument order is images / labels / reddening
            loss = loss + perturb.loss(new_img_batch, new_z_batch,
                                       extra=ebv_batch)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        print(f"Train Epoch: {epoch}, loss : {loss.item()}")

    return train_loss / n_loops if n_loops else 0.0

# ####################### MAIN ###########################


def main():
    """Train ``NetWithInception`` on random batches with a grad-L1 penalty.

    Command-line options:
    - ``--bs``: batch size.
    - ``--epochs``: number of epochs.
    - ``--nl``: batches per epoch.
    - ``--noincept``: bypass the Inception module.

    Runs on CUDA when available, otherwise on the CPU.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch PhotoZ')
    parser.add_argument('--bs', type=int, default=2, dest='batch_size',
                        help='batch size (1:Ok; 2:Crash)')
    parser.add_argument('--epochs', type=int, default=5, dest='epochs',
                        help='number of epochs')
    parser.add_argument('--nl', type=int, default=5, dest='nloop',
                        help='number of loops inside training')
    parser.add_argument('--noincept', dest='no_inception', default=False,
                        action='store_true')
    args = parser.parse_args()

    print("\n### Training model ###")
    print("> Parameters:")
    for p, v in vars(args).items():
        print('\t{}: {}'.format(p, v))

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Use device....: ", device)

    # The Network
    img_channels = 5  # number of filters
    model = NetWithInception(img_channels, no_inception=args.no_inception)

    # Module.to is in-place; move the model before building the optimizer
    # https://pytorch.org/docs/stable/nn.html#torch.nn.Module.to
    model.to(device)

    # losses: mean reduction for the data term, sum reduction for the
    # gradient penalty (LossAvgGradL1 divides by the batch size itself)
    loss_fn_mean = nn.CrossEntropyLoss()
    loss_fn_sum = nn.CrossEntropyLoss(reduction='sum')

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001,
                                momentum=0.9,
                                weight_decay=0.0,
                                nesterov=True)

    # perturbation/penalty
    perturb = LossAvgGradL1(model, loss_fn_sum, lmbda=0.01)

    start_epoch = 0

    #### GO!
    # --epochs N runs N epochs (the former range(start, epochs + 1) ran N + 1)
    for epoch in range(start_epoch, args.epochs):
        print("process epoch[", epoch, "]: LR = ", end='')
        for param_group in optimizer.param_groups:
            print(param_group['lr'])

        # training
        train_loss = train(args, model, device, optimizer, loss_fn_mean,
                           epoch, perturb=perturb)

        print('Epoch {}, Train Loss: {:.6f}'.format(epoch, train_loss))

    # End
    print("End of job. Bye")


################################
if __name__ == '__main__':
    main()

Breaking News:

If you replace PReLU with ReLU in the PzConv2d class

##        self.activ = nn.PReLU(num_parameters=n_out_channels, init=0.25)
        self.activ = nn.ReLU()

then the RuntimeError is not raised, even when running `python bugs_v3.py --bs 2`.

So, a hint: the parameters of the PReLU.

My minimal repro works fine, but the full script still fails, right?

Yes. Try my last script to investigate. Thanks.

Could you try to do the same thing I did above: remove as much as possible while still reproducing the error? That will help us narrow down the issue.

@albanD I will do my best but I am not an expert… :stuck_out_tongue:

@albanD

I have reduced the model, and it seems that if I touch the PzInception forward chain, other kinds of crashes appear… But I am not experienced enough to investigate further without introducing other bugs. Maybe you will find it faster.

class PzConv2d(nn.Module):
    """2-D convolution (default PyTorch init) followed by a channel-wise PReLU.

    Extra keyword arguments are forwarded to ``nn.Conv2d``; PReLU slopes start
    at 0.25. Per the thread, using ``nn.ReLU()`` instead of the PReLU made the
    reported crash disappear.
    """

    def __init__(self, n_in_channels, n_out_channels, **kwargs):
        super(PzConv2d, self).__init__()
        conv = nn.Conv2d(n_in_channels, n_out_channels, bias=True, **kwargs)
        prelu = nn.PReLU(num_parameters=n_out_channels, init=0.25)
        self.conv, self.activ = conv, prelu

    def forward(self, x):
        features = self.conv(x)
        return self.activ(features)


class PzPool2d(nn.Module):
    """Average pooling wrapper.

    Border windows that only partially overlap the input are kept
    (``ceil_mode=True``), and zero padding is not counted in the averages
    (``count_include_pad=False``).
    """

    def __init__(self, kernel_size, stride, padding=0):
        super(PzPool2d, self).__init__()
        self.pool = nn.AvgPool2d(
            kernel_size=kernel_size, stride=stride, padding=padding,
            ceil_mode=True, count_include_pad=False,
        )

    def forward(self, x):
        result = self.pool(x)
        return result


class PzFullyConnected(nn.Module):
    """Linear layer (default PyTorch init), optionally followed by a ReLU.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, n_inputs, n_outputs, withrelu=True, **kwargs):
        super(PzFullyConnected, self).__init__()
        self.withrelu = withrelu
        self.linear = nn.Linear(n_inputs, n_outputs, bias=True)
        self.activ = nn.ReLU()

    def forward(self, x):
        projected = self.linear(x)
        if not self.withrelu:
            return projected
        return self.activ(projected)


class PzInception(nn.Module):
    """Reduced Inception module with two parallel branches.

        o s1_0: 1x1 conv
        o s1_2: 1x1 conv -> pad0 (one zero row/column on the right/bottom)
          -> pool0 (2x2 average pooling, stride 1)

    The branch outputs are concatenated along dim 1 (s1_0 first), giving
    ``2 * n_out_channels_1`` channels. ``n_out_channels_2`` and
    ``without_kernel_5`` are accepted but unused in this reduced version.
    """

    def __init__(self, n_in_channels, n_out_channels_1, n_out_channels_2,
                 without_kernel_5=False, debug=False):
        super(PzInception, self).__init__()
        self.debug = debug
        self.s1_0 = PzConv2d(n_in_channels, n_out_channels_1,
                             kernel_size=1, padding=0)
        self.s1_2 = PzConv2d(n_in_channels, n_out_channels_1, kernel_size=1)
        self.pad0 = nn.ZeroPad2d([0, 1, 0, 1])
        self.pool0 = PzPool2d(kernel_size=2, stride=1, padding=0)

    def forward(self, x):
        direct = self.s1_0(x)
        pooled = self.pool0(self.pad0(self.s1_2(x)))
        return torch.cat((direct, pooled), dim=1)


class NetWithInception(nn.Module):
    """Reduced photo-z network used to reproduce the double-backward crash.

    conv0 (5x5, 64 channels, padding 2) -> pool0 (2x2 average, stride 2)
    -> i0 (two-branch Inception, skipped when ``no_inception`` is True; the
    module is still created) -> flatten, concatenate the reddening tensor
    -> fc0 (``n_bins`` outputs, with PzFullyConnected's default ReLU).

    Parameters
    ----------
    n_input_channels: number of image channels
    no_inception: bypass ``i0``
    debug: stored, not used by ``forward``
    img_size: height/width of the (square) input images; defaults to 64
    n_reddening: number of reddening features per sample; defaults to 1
    """

    def __init__(self, n_input_channels, no_inception=False, debug=False,
                 img_size=64, n_reddening=1):
        super(NetWithInception, self).__init__()

        # the number of bins to represent the output photo-z
        self.n_bins = 180

        self.debug = debug
        n_conv0 = 64
        self.conv0 = PzConv2d(n_in_channels=n_input_channels,
                              n_out_channels=n_conv0,
                              kernel_size=5, padding=2)
        self.pool0 = PzPool2d(kernel_size=2, stride=2, padding=0)

        n_out_1 = 48
        self.i0 = PzInception(n_in_channels=n_conv0,
                              n_out_channels_1=n_out_1,
                              n_out_channels_2=64)

        self.no_inception = no_inception

        # fc0 input size, derived instead of hard-coded (65537 / 98305 for
        # 64x64 images). conv0 keeps the spatial size, and pool0 halves it
        # (rounding up, ceil_mode=True). i0 keeps it while concatenating two
        # branches of n_out_1 channels each.
        side = (img_size + 1) // 2
        n_channels = n_conv0 if no_inception else 2 * n_out_1
        self.fc0 = PzFullyConnected(
            n_inputs=n_channels * side * side + n_reddening,
            n_outputs=self.n_bins)

    def num_flat_features(self, x):
        """

        Parameters
        ----------
        x: the input

        Returns
        -------
        the total number of features = number of elements of the tensor except the batch dimension

        """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def forward(self, x, reddening):
        x = self.pool0(self.conv0(x))
        if not self.no_inception:
            x = self.i0(x)
        # keep the batch on dim 0 and flatten everything else
        flat = torch.flatten(x, start_dim=1)
        return self.fc0(torch.cat((flat, reddening), dim=1))

OK, so I have squeezed the code as much as possible:
It crashes with PyTorch 1.4.0 and with the nightly build (with a different message).

Example of the error on CPU:

### Training model ###
> Parameters:
	batch_size: 2
	epochs: 5
	nloop: 5
Use device....:  cpu
Traceback (most recent call last):
  File "bugs_v4.py", line 117, in <module>
    main()
  File "bugs_v4.py", line 107, in main
    grad(g.sum(), new_img_batch)
  File "/..../anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 157, in grad
    inputs, allow_unused)
RuntimeError: tensor does not have a device

Could someone run the following stand-alone script to see whether it also raises an error in their configuration? Thanks.

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import torch.nn.functional as F
import random
import types
import numpy as np

import os


class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension (dim 0)."""

    def forward(self, input):
        # The former ``input.view(-1, input.size(0))`` did not keep the batch
        # on dim 0: for a batch of 2 it returned shape (numel / 2, 2).
        # Keep dim 0 and collapse the rest instead; torch.flatten also copes
        # with non-contiguous inputs.
        return torch.flatten(input, start_dim=1)

class PzInception(nn.Module):
    """Two parallel branches, each a 1x1 convolution followed by a PReLU.

    The branch outputs are concatenated along dim 1 (s1_0 branch first),
    giving ``2 * n_out_channels_1`` output channels.
    """

    def __init__(self, n_in_channels, n_out_channels_1):
        super(PzInception, self).__init__()
        # first branch: conv + PReLU
        self.s1_0 = nn.Conv2d(n_in_channels, n_out_channels_1,
                              kernel_size=1, padding=0, bias=True)
        self.activ10 = nn.PReLU(num_parameters=n_out_channels_1, init=0.25)
        # second branch: conv + PReLU
        self.s1_2 = nn.Conv2d(n_in_channels, n_out_channels_1,
                              kernel_size=1, bias=True)
        self.activ12 = nn.PReLU(num_parameters=n_out_channels_1, init=0.25)

    def forward(self, x):
        branches = [
            self.activ10(self.s1_0(x)),
            self.activ12(self.s1_2(x)),
        ]
        return torch.cat(branches, dim=1)


###
def main():
    """Stand-alone repro of "RuntimeError: tensor does not have a device".

    Builds a small model: conv -> PReLU -> avg-pool -> PzInception -> Flatten
    -> ReLU. For each random batch it performs two steps:
    1. Take the gradient of the summed output w.r.t. the input, with
       ``create_graph=True``.
    2. Take the gradient of that gradient (double backward).

    Per the thread, this crashed with batch size 2 on PyTorch 1.4.0.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch PhotoZ')
    parser.add_argument('--bs', type=int, default=2, dest='batch_size',
                        help='batch size (1:Ok; 2:Crash)')
    parser.add_argument('--epochs', type=int, default=5, dest='epochs',
                        help='number of epochs')
    parser.add_argument('--nl', type=int, default=5, dest='nloop',
                        help='number of loops inside training')
    args = parser.parse_args()

    print("\n### Training model ###")
    print("> Parameters:")
    for p, v in vars(args).items():
        print('\t{}: {}'.format(p, v))

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Use device....: ", device)

    img_channels = 5  # number of filters

    model = nn.Sequential(
        nn.Conv2d(img_channels, 64, 5, padding=2, bias=True),
        nn.PReLU(num_parameters=64, init=0.25),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=0,
                     ceil_mode=True, count_include_pad=False),
        PzInception(64, 48),
        Flatten(),
        nn.ReLU()  # if commented out, no crash
    )
    model.to(device)

    #### GO!
    model.train()
    n_loops = args.nloop
    img_H = 64
    img_W = 64
    n_batchs = args.batch_size

    # --epochs N runs N epochs (the former range(0, epochs + 1) ran N + 1)
    for epoch in range(args.epochs):
        for _ in range(n_loops):
            # this call was split mid-keyword ("dtyp" / "e=torch.float") in
            # the original paste, which is a SyntaxError
            new_img_batch = torch.randn(n_batchs, img_channels, img_H, img_W,
                                        dtype=torch.float)
            new_img_batch = new_img_batch.to(device)
            new_img_batch.requires_grad = True

            output = model(new_img_batch)
            # first-order input gradient, keeping its graph ...
            g, = grad(output.sum(), new_img_batch, create_graph=True)
            # ... so it can be differentiated again (double backward); this
            # is the call where the RuntimeError was reported
            grad(g.sum(), new_img_batch)

        print('Epoch {}'.format(epoch))

    # End
    print("End of job. Bye")


################################
if __name__ == '__main__':
    main()

Hi @albanD and @ptrblck

Sorry for the double post in @albanD's mailbox.

I have finally managed to squeeze my code as much as possible to report the problem:

# Minimal repro reported on PyTorch 1.4.0 (CPU). Two conv + PReLU branches are
# concatenated along dim 1, then a Linear over the last (width, 64) dimension
# and a Sigmoid are applied. Differentiating the input gradient a second time
# (double backward) raised "RuntimeError: tensor does not have a device".
# The statement order is kept exactly as reported so the autograd graph is
# unchanged.
# NOTE(review): relies on `torch`, `nn` and `grad` (torch.autograd.grad) being
# imported as in the scripts above — confirm when running it standalone.
x= torch.randn(2,5,64,64,dtype=torch.float)
x.requires_grad = True
# branch 1: 5x5 conv (padding 2 keeps 64x64) + PReLU, 64 channels
x1 = nn.Conv2d(5,64,5, padding=2, bias=True)(x)
x1 = nn.PReLU(64, init=0.25)(x1)
# branch 2: 3x3 conv (padding 1 keeps 64x64) + PReLU, 32 channels
x2 = nn.Conv2d(5,32,3, padding=1, bias=True)(x)
x2 = nn.PReLU(32, init=0.25)(x2)
# concatenating on dim 1 is what introduces the slicing in the backward graph
out = torch.cat((x1, x2), dim=1)
out = nn.Linear(64,1)(out)
out = nn.Sigmoid()(out)
# first-order gradient w.r.t. x, keeping the graph for a second derivative
g, = grad(out.sum(), x, create_graph=True)
# second-order gradient: the call that raised the RuntimeError. Per the
# maintainer's reply, possibly an undefined zero gradient not handled by
# SliceBackward — unconfirmed here.
grad(g.sum(), x)

Here is the report on CPU with PyTorch 1.4.0:

RuntimeError                              Traceback (most recent call last)
<ipython-input-109-a1fcf6095460> in <module>
      1 g, = grad(out.sum(), x, create_graph=True)
----> 2 grad(g.sum(), x)

~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
    155     return Variable._execution_engine.run_backward(
    156         outputs, grad_outputs, retain_graph, create_graph,
--> 157         inputs, allow_unused)
    158 
    159 

RuntimeError: tensor does not have a device

I hope this short snippet satisfies your requirements for problem reporting :slight_smile:

That helps a lot, thanks!

Looking at the graph this generates, it seems like the zero gradient here might not be handled properly by the SliceBackward that follows.

Let me double check and get back to you.

Good news :+1: now we have a starting point.

I'm facing the same issue, except that it only happens when using torch.jit.script on a module.

There’s absolutely nothing out of the ordinary in my module and I can’t get a MRE.

Hi @Enamex, have you seen "RuntimeError: tensor does not have a device?" (issue #33037)?

Have you solved the import error? I also had the error libtorch_python…

ImportError: /home/ubuntu/anaconda3/envs/myenv/lib/python3.6/site-packages/torch/lib/libtorch_python.so: undefined symbol: _ZN5torch4cuda4nccl6detail16throw_nccl_errorE12ncclResult_t