RuntimeError: tensor does not have a device?

Hi, I have some new material on this issue:

import torch
import torch.nn as nn

class PzConv2d(nn.Module):
    """ Convolution 2D Layer followed by PReLU activation
    """
    def __init__(self, n_in_channels, n_out_channels, **kwargs):
        super(PzConv2d, self).__init__()
        self.conv = nn.Conv2d(n_in_channels, n_out_channels, bias=True,
                              **kwargs)
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.constant_(self.conv.bias, 0.1)
        self.activ = nn.PReLU(num_parameters=n_out_channels, init=0.25)

    def forward(self, x):
        x = self.conv(x)
        return self.activ(x)


class PzPool2d(nn.Module):
    """ Average Pooling Layer
    """
    def __init__(self, kernel_size, stride, padding=0):
        super(PzPool2d, self).__init__()
        self.pool = nn.AvgPool2d(kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 ceil_mode=True,
                                 count_include_pad=False)

    def forward(self, x):
        return self.pool(x)


class PzFullyConnected(nn.Module):
    """ Dense or Fully Connected Layer followed by ReLU
    """
    def __init__(self, n_inputs, n_outputs, withrelu=True, **kwargs):
        super(PzFullyConnected, self).__init__()
        self.withrelu = withrelu
        self.linear = nn.Linear(n_inputs, n_outputs, bias=True)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0.1)
        self.activ = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        if self.withrelu:
            x = self.activ(x)
        return x


class PzInception(nn.Module):
    """ Inception module

        The input (x) is dispatched between

        o a cascade of conv layers s1_0 1x1, s2_0 3x3
        o a conv layer s1_2 1x1, followed by a pooling layer pool0 2x2
        o a single conv layer s2_2 1x1
        o optionally a cascade of conv layers s1_1 1x1, s2_1 5x5

        then the 3 (or 4) branch outputs are concatenated along the
        channel dimension
    """
    def __init__(self, n_in_channels, n_out_channels_1, n_out_channels_2,
                 without_kernel_5=False, debug=False):
        super(PzInception, self).__init__()
        self.debug = debug
        self.s1_0 = PzConv2d(n_in_channels, n_out_channels_1,
                             kernel_size=1, padding=0)
        self.s2_0 = PzConv2d(n_out_channels_1, n_out_channels_2,
                             kernel_size=3, padding=1)

        self.s1_2 = PzConv2d(n_in_channels, n_out_channels_1, kernel_size=1)
        self.pad0 = nn.ZeroPad2d([0, 1, 0, 1])
        self.pool0 = PzPool2d(kernel_size=2, stride=1, padding=0)

        self.without_kernel_5 = without_kernel_5
        if not (without_kernel_5):
            self.s1_1 = PzConv2d(n_in_channels, n_out_channels_1,
                                 kernel_size=1, padding=0)
            self.s2_1 = PzConv2d(n_out_channels_1, n_out_channels_2,
                                 kernel_size=5, padding=2)

        self.s2_2 = PzConv2d(n_in_channels, n_out_channels_2, kernel_size=1,
                             padding=0)

    def forward(self, x):
        # x: image tensor (N_batch, Channels, Height, Width)
        x_s1_0 = self.s1_0(x)
        x_s2_0 = self.s2_0(x_s1_0)

        x_s1_2 = self.s1_2(x)

        x_pool0 = self.pool0(self.pad0(x_s1_2))

        if not (self.without_kernel_5):
            x_s1_1 = self.s1_1(x)
            x_s2_1 = self.s2_1(x_s1_1)

        x_s2_2 = self.s2_2(x)

        if self.debug: print("Inception x_s1_0  :", x_s1_0.size())
        if self.debug: print("Inception x_s2_0  :", x_s2_0.size())
        if self.debug: print("Inception x_s1_2  :", x_s1_2.size())
        if self.debug: print("Inception x_pool0 :", x_pool0.size())

        if not (self.without_kernel_5) and self.debug:
            print("Inception x_s1_1  :", x_s1_1.size())
            print("Inception x_s2_1  :", x_s2_1.size())

        if self.debug: print("Inception x_s2_2  :", x_s2_2.size())

        # note: dim=1 is the channel axis for NCHW (in TensorFlow it would be axis=3 for NHWC)
        if not (self.without_kernel_5):
            output = torch.cat((x_s2_2, x_s2_1, x_s2_0, x_pool0), dim=1)
        else:
            output = torch.cat((x_s2_2, x_s2_0, x_pool0), dim=1)

        if self.debug: print("Inception output :", output.shape)
        return output


class NetWithInception(nn.Module):
    """ The network

        inputs: the image (x) and the reddening vector

        The 64x64x5 image is fed forward through
        o a conv layer 5x5
        o a pooling layer 2x2
        o 5 inception modules, the last one without the 5x5 branch

        Then the flattened result is concatenated with the reddening vector
        and passed through
        o 3 fully connected layers

        The output dimension is given by n_bins.
        There is no softmax activation here, so the raw logits can be fed
        directly to the cross-entropy loss.

    """
    def __init__(self, n_input_channels, debug=False):
        super(NetWithInception, self).__init__()
        
        # the number of bins to represent the output photo-z
        self.n_bins = 180

        self.debug = debug
        self.conv0 = PzConv2d(n_in_channels=n_input_channels,
                              n_out_channels=64,
                              kernel_size=5, padding=2)
        self.pool0 = PzPool2d(kernel_size=2, stride=2, padding=0)
        # Softmax note: for an input tensor of shape [1, n], apply along dim=1
        # t1 = torch.rand([1, 10])
        # t2 = nn.Softmax(dim=1)(t1)
        # torch.sum(t2) == 1
        self.i0 = PzInception(n_in_channels=64,
                              n_out_channels_1=48,
                              n_out_channels_2=64)

        self.i1 = PzInception(n_in_channels=240,
                              n_out_channels_1=64,
                              n_out_channels_2=92)

        self.i2 = PzInception(n_in_channels=340,
                              n_out_channels_1=92,
                              n_out_channels_2=128)

        self.i3 = PzInception(n_in_channels=476,
                              n_out_channels_1=92,
                              n_out_channels_2=128)

        self.i4 = PzInception(n_in_channels=476,
                              n_out_channels_1=92,
                              n_out_channels_2=128,
                              without_kernel_5=True)

        self.fc0 = PzFullyConnected(n_inputs=22273, n_outputs=1096)
        self.fc1 = PzFullyConnected(n_inputs=1096, n_outputs=1096)
        self.fc2 = PzFullyConnected(n_inputs=1096, n_outputs=self.n_bins)
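        # 22273 = 8*8*348 (flattened i4 output) + 1 (the reddening scalar)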


    def num_flat_features(self, x):
        """

        Parameters
        ----------
        x: the input

        Returns
        -------
        the totale number of features = number of elements of the tensor except the batch dimension

        """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
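    # NB: x.view(-1, self.num_flat_features(x)) used in forward() is
    # equivalent to x.view(x.size(0), -1) or torch.flatten(x, 1)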

    def forward(self, x, reddening):
        # x: image tensor (N_batch, Channels, Height, Width)
        #    here N batches, Channels = 5 filters, H = W = 64 pixels
        # save original image
        x_in = x

        if self.debug: print("input shape: ", x.size())
        x = self.conv0(x)
        if self.debug: print("conv0 shape: ", x.size())
        x = self.pool0(x)
        if self.debug: print("conv0p shape: ", x.size())
        if self.debug: print('>>>>>>> i0:START <<<<<<<')
        x = self.i0(x)
        if self.debug: print("i0 shape: ", x.size())

        if self.debug: print('>>>>>>> i1:START <<<<<<<')
        x = self.i1(x)

        x = self.pool0(x)
        if self.debug: print("i1p shape: ", x.size())

        if self.debug: print('>>>>>>> i2:START <<<<<<<')
        x = self.i2(x)
        if self.debug: print("i2 shape: ", x.size())

        if self.debug: print('>>>>>>> i3:START <<<<<<<')
        x = self.i3(x)
        x = self.pool0(x)
        if self.debug: print("i3p shape: ", x.size())

        if self.debug: print('>>>>>>> i4:START <<<<<<<')
        x = self.i4(x)
        if self.debug: print("i4 shape: ", x.size())

        if self.debug: print('>>>>>>> FC part :START <<<<<<<')
        flat = x.view(-1, self.num_flat_features(x))
        if self.debug: print("flat shape: ", flat.size())
        concat = torch.cat((flat, reddening), dim=1)
        if self.debug: print('concat shape: ', concat.size())

        fcn_in_features = concat.size(-1)
        if self.debug: print('fcn_in_features: ', fcn_in_features)

        x = self.fc0(concat)
        if self.debug: print('fc0 shape: ', x.size())
        x = self.fc1(x)
        if self.debug: print('fc1 shape: ', x.size())
        x = self.fc2(x)
        if self.debug: print('fc2 shape: ', x.size())

        output = x
        if self.debug: print('output shape: ', output.size())

        #params = {"output": output, "x": x_in, "reddening": reddening}
        # return params

        return output

########

img_channels = 5
img_H = 64
img_W = 64
n_batchs = 1

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("device: ",device)

loss_fn = torch.nn.CrossEntropyLoss()
loss_fn_sum = torch.nn.CrossEntropyLoss(reduction='sum')
#m = torch.nn.Linear(20, 30)

m = NetWithInception(img_channels, debug=True)
m.to(device)

ims = torch.randn(n_batchs, img_channels, img_H, img_W, dtype=torch.float)
reds = torch.zeros([n_batchs, 1], dtype=torch.float)
target = torch.empty(n_batchs, dtype=torch.long).random_(180)  # 180 = n_bins
ims, target, reds = ims.to(device), target.to(device), reds.to(device)

pred = m(ims, reds)
loss = loss_fn(pred, target)

# gradient penalty: differentiate the loss w.r.t. the input images with
# create_graph=True, so the penalty itself remains differentiable and can
# be backpropagated through later (double backward)
n = ims.shape[0]
imsv = ims.clone().requires_grad_()
preds = m(imsv, reds)
xeloss = loss_fn(preds, target)
g, = torch.autograd.grad(xeloss, imsv, create_graph=True)
penalty = g.norm(1) / n  # mean L1 norm of the input gradient


print("loss   : ",loss)
print("penalty: ",penalty)

loss += penalty
tmp = loss.item()
print("loss   : ",loss)
# compute gradients: this double backward through the penalty is where the error below occurs
loss.backward()
print("end")

Can someone run this code and tell me if they get the following message:

device:  cuda
input shape:  torch.Size([1, 5, 64, 64])
conv0 shape:  torch.Size([1, 64, 64, 64])
conv0p shape:  torch.Size([1, 64, 32, 32])
>>>>>>> i0:START <<<<<<<
i0 shape:  torch.Size([1, 240, 32, 32])
>>>>>>> i1:START <<<<<<<
i1p shape:  torch.Size([1, 340, 16, 16])
>>>>>>> i2:START <<<<<<<
i2 shape:  torch.Size([1, 476, 16, 16])
>>>>>>> i3:START <<<<<<<
i3p shape:  torch.Size([1, 476, 8, 8])
>>>>>>> i4:START <<<<<<<
i4 shape:  torch.Size([1, 348, 8, 8])
>>>>>>> FC part :START <<<<<<<
flat shape:  torch.Size([1, 22272])
concat shape:  torch.Size([1, 22273])
fcn_in_features:  22273
fc0 shape:  torch.Size([1, 1096])
fc1 shape:  torch.Size([1, 1096])
fc2 shape:  torch.Size([1, 180])
output shape:  torch.Size([1, 180])
input shape:  torch.Size([1, 5, 64, 64])
conv0 shape:  torch.Size([1, 64, 64, 64])
conv0p shape:  torch.Size([1, 64, 32, 32])
>>>>>>> i0:START <<<<<<<
i0 shape:  torch.Size([1, 240, 32, 32])
>>>>>>> i1:START <<<<<<<
i1p shape:  torch.Size([1, 340, 16, 16])
>>>>>>> i2:START <<<<<<<
i2 shape:  torch.Size([1, 476, 16, 16])
>>>>>>> i3:START <<<<<<<
i3p shape:  torch.Size([1, 476, 8, 8])
>>>>>>> i4:START <<<<<<<
i4 shape:  torch.Size([1, 348, 8, 8])
>>>>>>> FC part :START <<<<<<<
flat shape:  torch.Size([1, 22272])
concat shape:  torch.Size([1, 22273])
fcn_in_features:  22273
fc0 shape:  torch.Size([1, 1096])
fc1 shape:  torch.Size([1, 1096])
fc2 shape:  torch.Size([1, 180])
output shape:  torch.Size([1, 180])
loss   :  tensor(5.4763, device='cuda:0', grad_fn=<NllLossBackward>)
penalty:  tensor(0.1425, device='cuda:0', grad_fn=<DivBackward0>)
loss   :  tensor(5.6188, device='cuda:0', grad_fn=<AddBackward0>)
Traceback (most recent call last):
  File "bugs.py", line 301, in <module>
    loss.backward()
  File "...anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: sizes() called on undefined Tensor

It is a different error, but it may be linked to the original one.
The same problem occurs on "cpu" and on "cuda".
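
In case it helps to narrow this down: the only non-standard activation in the model is nn.PReLU, so below is a minimal sketch that only double-backwards through a single PReLU layer. This is just my guess at an isolation test, not a confirmed diagnosis, and the layer size and input shape are arbitrary. If it raises the same RuntimeError, the problem would be in PReLU's double backward rather than in the network itself.

import torch
import torch.nn as nn

# isolated double-backward test through nn.PReLU only (arbitrary shapes)
act = nn.PReLU(num_parameters=3)
x = torch.randn(2, 3, 8, 8, requires_grad=True)

out = act(x).sum()
# first-order gradient w.r.t. the input, keeping the graph for double backward
g, = torch.autograd.grad(out, x, create_graph=True)
penalty = g.norm(1) / x.shape[0]

# second backward pass: the step that fails in the full model
penalty.backward()
print("double backward through PReLU succeeded")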