Combining Vgg16 and DenseNet

I am trying to combine vgg16 model with densenet121 model and I am having problems.
We are using 3×224×224 images.
I am a starter, so please explain clearly.
class MyEnsemble(nn.Module):
def init(self,modelA,modelB):
super(MyEnsemble,self).init()
self.modelA = modelA
self.modelB = modelB
self.classifier=nn.Linear(4,2)

  def forward(self,x1,x2):
        x1=self.modelA(x1)
        x2=self.modelB(x2)
        x=torch.cat((x1,x2),dim=1)
       x=self.classifier(F.relu(x))
  return x

model_combined=MyEnsemble(model_vgg16,model_densenet121)
x1,x2=torch.randn(1,256),torch.rand(1,64)
output=model_combined(x1,x2)

It gives Runtime Error: Expected object of device type cuda but got device type cpu for aegument #1 ‘self’ in call to _thnn_conv2d_forward

Are you passing a input with cuda into the model?

1 Like

I am using CUDA for training vgg16 and densenet121 models, so I think yes.

Can I see the code where you give the model the input. This error is happening because either the input or the model has not been moved to cuda. So you need to move either the input or model to cuda whichever you haven’t done yet.

1 Like

datadir= ‘./archive/chest_xray/chest_xray’
traindir = datadir + ‘/train/’
validdir = datadir + ‘/val/’
testdir = datadir + ‘/test/’

categories = []
img_categories = []
n_train = []
n_valid = []
n_test = []
hs = []
ws = []

for d in os.listdir(traindir):
if not d.startswith(’.’):
categories.append(d)

train_imgs = os.listdir(traindir + d)
valid_imgs = os.listdir(validdir + d)
test_imgs = os.listdir(testdir + d)
n_train.append(len(train_imgs))
n_valid.append(len(valid_imgs))
n_test.append(len(test_imgs))

    for i in train_imgs:
        if not i.startswith('.'):
            img_categories.append(d)
            img = Image.open(traindir + d + '/' + i)
            img_array = np.array(img)
            hs.append(img_array.shape[0])
            ws.append(img_array.shape[1])


cat_df = pd.DataFrame({‘category’: categories,
‘n_train’: n_train,
‘n_valid’: n_valid, ‘n_test’: n_test}).
sort_values(‘category’)

image_df = pd.DataFrame({
‘category’: img_categories,
‘height’: hs,
‘width’: ws
})

data = {
‘train’:
datasets.ImageFolder(root=traindir, transform=image_transforms[‘train’]),
‘val’:
datasets.ImageFolder(root=validdir, transform=image_transforms[‘val’]),
‘test’:
datasets.ImageFolder(root=testdir, transform=image_transforms[‘test’])
}

dataloaders = {
‘train’: DataLoader(data[‘train’], batch_size=batch_size, shuffle=True),
‘val’: DataLoader(data[‘val’], batch_size=batch_size, shuffle=True),
‘test’: DataLoader(data[‘test’], batch_size=batch_size, shuffle=True)
}
trainiter = iter(dataloaders[‘train’])
features, labels = next(trainiter)
features.shape, labels.shape
(torch.Size([20, 3, 224, 224]), torch.Size([20]))
n_classes = len(cat_df)

I am using Kaggle’s Chest XRay dataset. It looks like this in folder. chest_xray((train(NORMAL,PNEUMONIA)),(val(NORMAL,PNEUMONIA)),(test(NORMAL,PNEUMONIA)))

Do you call model.cuda() or model.to(device) in your code? And do you call inputs.cuda() or .to(device) ?

1 Like

I used model.to(‘cuda’).

ok so you have to also do that for the inputs. so do features.to(‘cuda’) and labels.to(‘cuda’) in you training loop.

1 Like
        for ii, (data, target) in enumerate(train_loader):
            # Tensors to gpu
            if train_on_gpu:
                data, target = data.cuda(), target.cuda()

I have done that before. I still get the error.

when you define the combined model are you passing it to cuda like this

model_combined.cuda()

Also which error is in the line no? Can you just print out the full error you get from your training loop.

1 Like

I have trained my Densenet and vgg model succesfully. After that,

class MyEnsemble(nn.Module):
    def _init_(self, modelA, modelB):
        super(MyEnsemble, self)._init_()
        self.modelA = modelA
        self.modelB = modelB
        self.classifier = nn.Linear(4, 2)
        
    def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2), dim=1)
        x = self.classifier(F.relu(x))
        return x

        model_ensemb = MyEnsemble(model_vgg, model_dense)
        x1, x2 = torch.randn(64,3,3,3), torch.randn(1, 64)
        output = model_ensemb(x1, x2)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-35-cea5a04b5109> in <module>
      1 model_ensemb = MyEnsemble(model_vgg, model_dense)
      2 x1, x2 = torch.randn(64,3,3,3), torch.randn(1, 64)
----> 3 output = model_ensemb(x1, x2)

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torch\nn\modules\module.py in _call_(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

<ipython-input-24-40da828f3d9f> in forward(self, x1, x2)
      7 
      8     def forward(self, x1, x2):
----> 9         x1 = self.modelA(x1)
     10         x2 = self.modelB(x2)
     11         x = torch.cat((x1, x2), dim=1)

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torch\nn\modules\module.py in _call_(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torchvision\models\vgg.py in forward(self, x)
     41 
     42     def forward(self, x):
---> 43         x = self.features(x)
     44         x = self.avgpool(x)
     45         x = torch.flatten(x, 1)

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torch\nn\modules\module.py in _call_(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torch\nn\modules\module.py in _call_(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torch\nn\modules\conv.py in forward(self, input)
    343 
    344     def forward(self, input):
--> 345         return self.conv2d_forward(input, self.weight)
    346 
    347 class Conv3d(_ConvNd):

c:\users\2115\appdata\local\programs\python\python38\lib\site-packages\torch\nn\modules\conv.py in conv2d_forward(self, input, weight)
    339                             weight, self.bias, self.stride,
    340                             _pair(0), self.dilation, self.groups)
--> 341         return F.conv2d(input, weight, self.bias, self.stride,
    342                         self.padding, self.dilation, self.groups)
    343 

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _thnn_conv2d_forward

I am trying model_ensemb.cuda(). If that is the solution, I am going to feel really bad:D

Looking at the error that doesn’t look like the solution. What are you passing as x1 and x2 in the training loop?

1 Like

Ok wait the error in your code is happening because x1 and x2 are not initialized to cuda. You just define them like this

x1, x2 = torch.randn(64,3,3,3), torch.randn(1, 64)

you need to put those on cuda.

1 Like

So I just add this code,

x1 = torch.randn(64,3,3,3)
x1.to('cuda')

Is should be ok, right?
Thanks for your answers. I want to learn while doing something.

Yes that should work.

1 Like

Make sure to assign the tensor, as the to() operation won’t work inplace on tensors:

x1 = x1.to('cuda')

When I changed from x1 = torch.randn(64,3,3,3).to('cuda') to x1 = x1.to('cuda'), this error raise.

tensor([[[[ 1.3738e+00,  1.6170e+00,  3.4387e-01,  ...,  1.2356e+00,
            1.9411e+00,  6.9257e-02],
          [ 8.6033e-01, -1.2280e+00, -7.4821e-01,  ...,  6.2365e-02,
            2.0816e-01, -5.4607e-01],
          [-1.6347e-01,  1.2583e+00,  2.7074e+00,  ..., -1.3363e+00,
            1.2447e+00,  3.6719e-01],
          ...,
          [-1.1301e+00,  1.9237e-01,  2.2795e-02,  ..., -1.1974e+00,
           -9.7883e-01,  1.5822e+00],
          [-4.1928e-01, -6.5852e-01,  3.1132e-01,  ..., -2.0799e+00,
           -1.1829e-01,  3.1081e-01],
          [ 1.5343e+00, -1.2206e+00,  2.4044e-01,  ..., -8.0223e-01,
            1.3276e-01,  1.0792e+00]],

         [[ 1.6607e+00, -1.6649e+00, -2.1149e+00,  ..., -7.1301e-01,
            1.1041e+00,  5.7723e-03],
          [ 2.0768e+00,  2.3059e+00,  5.7799e-01,  ..., -3.1723e-01,
           -1.7190e-01, -2.0900e+00],
          [ 1.6640e+00,  1.2190e+00,  1.0463e+00,  ...,  1.4470e+00,
            4.3883e-01, -2.5957e-01],
          ...,
          [-9.3895e-01,  9.5411e-01, -6.8499e-01,  ...,  2.4570e+00,
            3.5444e-01, -1.5348e+00],
          [-5.0827e-01, -3.7860e-01, -5.7921e-01,  ...,  4.0725e-01,
           -2.1304e-01,  5.3209e-01],
          [ 2.3992e-01, -4.1098e-01,  3.3001e-01,  ..., -3.8462e-01,
           -8.5254e-01,  2.1340e-02]],

         [[-5.9497e-01, -2.0837e-01, -5.2536e-02,  ...,  1.3390e+00,
            1.4645e-01, -4.0928e-01],
          [-3.9733e+00, -1.9001e+00,  3.0652e-02,  ..., -6.3208e-01,
           -1.0671e-01,  8.1314e-01],
          [ 3.9714e-01, -1.6667e+00,  8.9182e-01,  ...,  4.2414e-01,
           -5.0485e-01, -8.9300e-01],
          ...,
          [-8.8482e-01,  3.3306e-01,  9.3603e-01,  ...,  9.6793e-01,
           -1.5524e+00, -7.0867e-01],
          [-9.6362e-01, -1.2317e-01,  8.5142e-01,  ...,  3.3439e-01,
           -1.0897e+00, -1.4645e+00],
          [ 5.2140e-01, -3.1554e-01,  1.1587e-01,  ..., -1.8543e-01,
           -8.1622e-01, -9.1532e-01]]],


        [[[ 1.1286e+00,  3.4704e+00, -8.2094e-01,  ...,  1.7327e+00,
           -4.2362e-01,  1.6319e+00],
          [ 6.1657e-01, -3.7871e-01,  3.4595e-01,  ...,  1.0394e+00,
            5.2830e-01, -3.0536e-02],
          [-2.1796e+00,  2.6964e-01,  2.3525e-01,  ..., -9.6463e-01,
           -1.4783e+00, -1.5843e+00],
          ...,
          [-3.2804e-01, -2.3166e+00, -1.9219e+00,  ...,  5.2085e-01,
           -2.8743e-01,  2.3273e-01],
          [-1.5173e+00,  1.2699e-01, -1.3232e-01,  ..., -1.0852e+00,
            1.1014e-02,  1.3600e+00],
          [-2.5993e-01,  1.0303e+00,  7.3986e-01,  ...,  1.0548e+00,
           -1.9268e-01,  1.6512e+00]],

         [[-5.3561e-02,  1.0735e+00,  1.0467e+00,  ...,  1.1428e+00,
           -9.2639e-01,  8.4774e-01],
          [-2.9379e-01, -1.0174e+00,  2.9117e+00,  ..., -1.2382e+00,
            1.4157e+00, -1.8998e+00],
          [ 1.8100e+00, -4.8207e-01,  1.1049e+00,  ..., -1.1973e-01,
           -8.9505e-01, -8.1042e-01],
          ...,
          [-1.0689e+00,  4.9736e-01,  1.2477e-01,  ...,  8.1400e-01,
            3.6573e-01, -1.1143e+00],
          [-1.1472e+00,  4.0190e-01, -2.8422e-02,  ..., -5.7513e-01,
            2.7724e-01,  1.2627e+00],
          [ 7.4237e-02,  1.0220e+00, -3.4520e-01,  ...,  1.8285e+00,
            1.0429e+00,  1.1795e+00]],

         [[ 1.7907e+00,  2.3012e-01, -3.5241e-01,  ...,  1.3356e+00,
           -8.4256e-01, -1.5806e+00],
          [ 6.1897e-01, -7.0960e-01,  1.1275e+00,  ...,  1.0590e+00,
           -2.9187e-01, -2.3310e-01],
          [ 5.0262e-01,  1.6449e+00, -2.5882e-01,  ..., -3.5218e-01,
            4.9902e-01,  7.2268e-01],
          ...,
          [-1.7285e-01,  4.6523e-01, -1.9443e-01,  ..., -1.5816e-01,
           -1.5420e+00,  1.3576e-01],
          [ 7.4234e-01, -2.7800e-01, -2.7359e-01,  ..., -3.3877e-01,
           -6.4417e-01, -6.6870e-01],
          [-5.5213e-01,  1.8822e+00,  2.0648e-01,  ...,  1.4829e+00,
           -1.4162e+00,  1.5254e-01]]],


        [[[-3.8275e-01,  8.7469e-01, -1.5699e+00,  ..., -1.6812e+00,
            7.8721e-02, -1.1511e+00],
          [ 1.5074e+00, -1.9959e+00, -3.4053e-01,  ..., -2.1573e+00,
           -2.2090e-01,  1.0753e+00],
          [ 1.2741e+00,  5.5767e-01,  2.2913e-01,  ..., -1.4749e+00,
           -1.2048e+00,  6.6352e-01],
          ...,
          [-1.5790e+00,  4.4862e-01, -2.4831e+00,  ...,  3.1354e-01,
            1.9212e+00, -6.8948e-01],
          [-1.5892e+00,  4.8342e-01,  1.5621e-01,  ..., -3.2899e-01,
           -2.6389e-01, -7.0431e-02],
          [-1.8056e-01,  2.5072e+00, -1.6480e+00,  ...,  7.0046e-01,
            4.5109e-01, -6.5162e-01]],

         [[ 2.4846e-01, -4.7023e-01,  1.4911e+00,  ..., -5.9566e-01,
           -1.9770e-01,  1.0534e+00],
          [ 5.3672e-01,  1.3595e-01, -7.3381e-01,  ...,  1.2827e+00,
           -5.1529e-01, -9.9328e-01],
          [-1.5824e-01, -7.9542e-01, -5.8690e-01,  ...,  2.1482e+00,
            2.8593e-01,  8.0048e-01],
          ...,
          [ 1.3970e+00,  8.5501e-02, -2.2210e-01,  ...,  7.7839e-01,
           -8.9441e-01, -6.2791e-01],
          [-1.2697e-01, -1.4650e-03, -1.0684e+00,  ...,  1.2182e+00,
            1.3189e-01, -2.9822e-01],
          [ 2.4211e-02,  1.8847e+00,  1.6538e+00,  ...,  2.3001e-02,
           -4.7337e-01,  7.4863e-01]],

         [[ 1.8260e-01,  5.3235e-01,  1.7497e+00,  ...,  1.6293e-01,
            4.6588e-01,  7.9873e-01],
          [-7.0679e-01, -3.8120e-01,  8.7782e-01,  ..., -1.5048e+00,
           -1.0785e+00,  6.5039e-01],
          [ 1.4715e+00, -1.5731e+00,  2.9692e-01,  ...,  3.8459e-01,
            1.8369e+00, -1.3718e+00],
          ...,
          [-2.4298e+00,  6.2362e-01, -4.0652e-01,  ..., -1.0254e+00,
            1.7386e+00,  7.6679e-01],
          [ 2.0310e-02, -8.6995e-01, -1.8608e+00,  ...,  1.0055e+00,
           -1.0898e+00,  1.5182e+00],
          [ 5.8819e-01,  1.1132e+00,  1.1230e+00,  ...,  1.8339e+00,
           -3.3747e-01, -7.6709e-01]]],


        ...,


        [[[-4.2959e-01, -1.8513e-01,  4.0738e+00,  ..., -9.8853e-02,
           -1.5748e+00,  3.3738e-01],
          [ 6.1330e-01,  1.8404e-01,  5.1972e-01,  ..., -3.0442e-01,
            6.8433e-01, -2.4026e-01],
          [ 2.4799e-01, -6.4763e-01, -3.4349e-02,  ...,  8.5829e-01,
            2.0179e+00, -1.1237e+00],
          ...,
          [ 6.2445e-01,  8.9477e-01,  1.3049e+00,  ...,  1.1319e+00,
           -6.1194e-01,  2.0845e+00],
          [ 1.0971e+00,  4.8079e-01, -2.2135e+00,  ...,  1.3091e-01,
            1.6692e+00,  4.1008e-01],
          [ 1.7660e+00,  3.0535e-01,  1.7169e-01,  ..., -1.0246e+00,
           -7.7203e-01, -2.3545e-01]],

         [[ 1.6164e-01,  1.2847e+00,  4.8990e-01,  ...,  5.1253e-01,
            2.8960e-01, -1.6328e-01],
          [-3.8200e-02,  2.4812e-01, -1.3049e+00,  ..., -3.6362e-02,
            1.2448e+00, -2.0191e-01],
          [-1.3291e+00, -1.2849e+00,  4.9813e-01,  ..., -9.2169e-01,
           -3.2721e-01,  3.0401e-01],
          ...,
          [ 1.3784e-01,  1.8349e+00,  1.0323e+00,  ..., -3.7623e-02,
            1.1804e+00,  7.5343e-02],
          [ 1.6813e-01,  6.4561e-01, -7.1513e-01,  ..., -8.8795e-01,
            1.6690e-01,  4.5098e-01],
          [ 1.7423e+00,  1.1711e+00,  2.7409e-01,  ...,  5.5110e-01,
            1.5460e+00,  1.2140e+00]],

         [[ 1.1153e+00, -7.5151e-01,  4.5126e-01,  ..., -7.7083e-01,
            1.7094e-01, -5.6937e-01],
          [-3.7550e-01, -2.3961e+00,  1.0611e+00,  ..., -9.6623e-01,
            2.5306e-01,  4.7367e-01],
          [ 4.1232e-01, -2.8777e-01,  2.5660e-01,  ...,  8.5573e-01,
            1.2591e+00,  9.9528e-01],
          ...,
          [ 9.3418e-01,  6.8896e-01, -2.2767e-01,  ..., -1.4142e+00,
           -5.7188e-01, -1.2695e-01],
          [-7.5224e-01, -1.4606e+00, -2.3310e-01,  ...,  1.3153e-01,
            1.1508e+00, -2.3101e-01],
          [ 1.4348e-01,  1.0346e+00, -1.3442e+00,  ...,  1.5408e+00,
           -1.6907e+00, -1.8329e+00]]],


        [[[ 3.7357e-01,  1.6450e+00,  1.6331e+00,  ..., -5.2163e-01,
            8.3834e-01,  3.3085e-01],
          [ 1.9456e-02,  4.9211e-04,  2.5807e+00,  ...,  8.1781e-02,
            6.6414e-01,  1.3460e+00],
          [ 1.4648e+00,  4.8832e-01,  1.2842e+00,  ...,  9.5314e-01,
           -9.8126e-01, -5.2367e-01],
          ...,
          [ 1.5011e+00, -1.5576e+00,  3.2545e-01,  ...,  3.5275e-01,
            2.6790e-01,  8.9674e-01],
          [ 3.8102e-01,  5.5054e-01,  6.9539e-01,  ...,  3.3880e-01,
           -1.3110e-01,  6.8279e-01],
          [ 4.0464e-01, -1.4554e-01,  3.7654e-01,  ..., -9.7156e-01,
           -6.5663e-01,  1.3549e+00]],

         [[-4.1428e-01, -1.1398e+00, -1.1962e-01,  ..., -3.7034e-01,
            8.9265e-01,  9.9764e-01],
          [-8.8790e-01,  6.5594e-01,  9.2387e-01,  ...,  9.1058e-01,
           -4.2757e-01, -3.8522e-01],
          [ 1.2396e+00,  8.5733e-01,  6.2273e-01,  ...,  4.4063e-02,
            1.1090e+00, -1.5577e+00],
          ...,
          [ 4.9486e-01, -1.0145e+00,  1.2137e-01,  ...,  1.7008e+00,
            7.8646e-01, -2.2173e-01],
          [-9.7285e-01,  1.4983e-01,  1.4303e+00,  ...,  1.7138e+00,
            1.5995e+00,  2.4027e+00],
          [ 4.0199e-01,  1.3294e+00,  1.5364e+00,  ..., -1.2339e+00,
           -1.3208e-01,  8.5768e-01]],

         [[-2.0088e+00,  8.9149e-01,  6.1231e-01,  ..., -4.3684e-02,
            1.7237e+00,  1.4563e-01],
          [ 8.1841e-01, -4.4185e-01, -9.4847e-01,  ..., -1.5341e-01,
           -9.1681e-01,  2.1837e+00],
          [-8.0557e-02,  6.3349e-01, -2.5978e+00,  ...,  1.3129e-01,
           -1.2524e+00, -1.0857e+00],
          ...,
          [ 1.0355e+00, -7.0865e-01, -3.0857e-01,  ...,  2.0053e+00,
           -2.4784e+00, -2.6277e+00],
          [ 9.8831e-01,  5.6888e-01,  1.6739e+00,  ...,  1.0684e+00,
           -2.8285e-01,  2.0879e-01],
          [-1.6054e-01,  8.0198e-01,  6.9564e-01,  ..., -1.0167e+00,
           -2.2463e+00, -2.4231e-01]]],


        [[[ 1.0925e+00, -8.6836e-01, -8.1109e-01,  ...,  5.4255e-01,
            3.5613e-01, -1.9508e-01],
          [-9.3833e-01,  7.9345e-01,  1.8179e-01,  ...,  2.1028e-01,
            9.0185e-02,  1.9536e-01],
          [ 1.7591e+00, -1.8813e+00,  7.5658e-01,  ..., -2.0238e-01,
            1.4689e+00, -1.2882e+00],
          ...,
          [ 1.0370e+00, -8.2894e-01,  1.3028e-02,  ...,  1.4212e+00,
           -1.5093e-01, -5.1726e-01],
          [-3.7768e-01,  1.2823e+00, -5.4269e-01,  ..., -1.8225e+00,
           -2.0913e-01,  1.8925e-01],
          [-1.4132e+00,  1.0184e+00,  5.2409e-01,  ..., -1.9754e-01,
           -6.0697e-01,  1.4865e+00]],

         [[-2.2553e-01,  8.7043e-01,  2.0090e-01,  ..., -1.3985e-01,
            1.4308e+00, -1.6350e+00],
          [-4.0821e-01, -3.5118e-01, -4.7131e-01,  ...,  4.0735e-01,
           -7.0374e-01, -1.4130e-01],
          [-7.3163e-01,  7.8971e-01, -2.0104e+00,  ..., -1.3375e+00,
           -5.7838e-01,  1.6582e+00],
          ...,
          [ 8.4811e-01,  4.6538e-01,  7.9469e-01,  ..., -1.6800e+00,
            1.0100e+00, -6.5119e-01],
          [ 2.0799e-01,  2.8471e+00,  4.2669e-01,  ..., -5.2057e-01,
           -7.3034e-01,  3.3939e-01],
          [-3.5454e-01,  3.2142e-01,  6.3146e-01,  ..., -2.4253e-01,
           -2.2102e-01,  4.6814e-01]],

         [[-1.5944e-01,  6.5645e-01,  1.2852e+00,  ..., -1.6736e-01,
           -1.0751e+00, -3.8616e-01],
          [-1.0053e+00, -5.1309e-01, -2.1615e-01,  ...,  4.9477e-01,
           -1.6984e-01, -4.5969e-01],
          [ 4.2761e-01, -2.1392e-01,  8.5068e-01,  ..., -3.0270e+00,
           -5.2644e-01, -1.5025e+00],
          ...,
          [ 5.4453e-01,  8.0921e-01, -1.6397e+00,  ...,  3.6729e-01,
            1.6886e-01, -9.3092e-01],
          [-8.6155e-01,  1.8004e+00,  3.3496e-01,  ...,  8.4892e-01,
           -2.0713e-01, -1.0664e+00],
          [-7.0426e-02,  1.4608e+00,  6.8907e-01,  ..., -6.2854e-01,
            1.5516e-01,  5.6305e-01]]]], device='cuda:0')tensor([[[[ 1.9630e+00,  3.8404e-01, -5.5180e-01,  ...,  8.5129e-01,
           -2.7766e-01, -9.5084e-01],
          [ 1.1394e+00, -3.8584e-01,  6.1361e-01,  ..., -2.9312e-01,
           -6.4726e-01,  1.4930e+00],
          [-9.8322e-01,  2.1182e+00,  3.7215e-01,  ...,  1.2371e+00,
            5.9932e-01, -1.4091e+00],
          ...,
          [-1.6244e+00,  1.3689e+00, -1.6173e-01,  ...,  1.8017e+00,
           -7.7667e-01,  1.4552e+00],
          [-2.9311e+00, -1.4496e+00, -1.5120e+00,  ...,  5.0621e-01,
            3.4139e-01,  3.4854e-01],
          [ 3.6875e-01,  2.5200e-01,  1.0731e+00,  ..., -4.6104e-01,
            6.0048e-01,  1.1450e-02]],

         [[-4.4875e-01, -1.7632e+00, -6.7356e-01,  ...,  5.9062e-01,
           -5.3563e-01, -9.2294e-01],
          [-1.3688e+00, -1.0022e+00,  1.5060e+00,  ..., -5.8654e-02,
            2.0467e+00,  4.7006e-01],
          [ 1.8755e+00,  3.0471e-01, -1.5458e+00,  ..., -1.6941e+00,
            6.3264e-01,  6.3953e-01],
          ...,
          [-1.7638e+00,  2.4371e+00,  9.5417e-01,  ...,  8.5064e-01,
           -9.5826e-01,  1.5330e+00],
          [-1.1763e+00,  1.2015e-01,  5.5135e-01,  ...,  9.8510e-03,
            9.6703e-01, -1.2908e+00],
          [ 1.1870e+00,  4.6432e-01,  3.0974e-01,  ..., -6.3579e-01,
            4.3054e-01,  1.1779e+00]],

         [[ 5.7441e-01,  1.4737e+00,  1.0305e+00,  ..., -9.9783e-01,
           -1.7123e+00, -8.0892e-01],
          [ 1.6251e+00, -1.3617e+00, -3.1943e-01,  ...,  8.4235e-01,
           -2.9847e-01, -2.3213e-01],
          [ 7.4568e-01,  7.8106e-01, -1.4361e+00,  ..., -4.9696e-01,
            1.3868e+00, -3.4384e-01],
          ...,
          [ 8.0754e-01, -5.7615e-01,  7.3965e-01,  ..., -5.0789e-01,
           -1.0228e+00,  4.3835e-01],
          [-1.0120e-01,  1.0970e+00, -4.3100e-01,  ...,  4.7326e-02,
            3.8954e-01, -8.3568e-01],
          [-2.8438e-01,  1.6836e+00, -4.2749e-01,  ...,  1.5726e+00,
           -1.5728e+00, -2.9523e-01]]],


        [[[-8.2771e-01, -4.1231e-01, -6.2533e-02,  ..., -6.0629e-01,
           -6.4560e-02,  1.5112e-01],
          [-1.3803e-01,  1.7187e+00, -9.5359e-01,  ..., -1.3095e+00,
            1.4077e+00,  4.1182e-01],
          [-3.0824e-01,  1.1287e-01, -1.9721e+00,  ...,  2.6124e-01,
           -4.3414e-01, -8.7193e-01],
          ...,
          [-2.0622e+00, -4.7485e-01,  1.6385e+00,  ...,  1.6190e+00,
            7.5018e-01,  1.0480e-01],
          [-1.0235e+00, -7.8345e-01, -4.8446e-02,  ..., -1.4287e+00,
            7.4898e-01, -9.6946e-01],
          [ 1.4209e+00, -1.1821e+00,  5.9781e-01,  ...,  1.9127e-01,
            7.5436e-01, -9.6033e-01]],

         [[ 1.2246e+00,  6.4567e-02, -1.4527e+00,  ...,  6.5273e-01,
            4.2086e-01,  1.5603e+00],
          [-6.2193e-01,  1.8753e+00,  2.8092e-01,  ..., -7.4377e-01,
            2.4196e+00, -3.8298e-01],
          [ 4.0220e-01,  1.7046e-03, -9.1700e-01,  ..., -2.9472e+00,
            1.9298e+00, -9.5318e-01],
          ...,
          [-1.6360e-01, -3.8656e-01,  1.0750e+00,  ..., -8.1048e-03,
            4.0365e-01,  3.6127e-01],
          [ 4.8787e-01,  3.6298e-01, -4.0933e-01,  ...,  1.0163e+00,
           -6.4409e-01, -3.0331e-01],
          [-1.0394e+00,  3.5060e-01,  1.6686e+00,  ..., -5.6064e-02,
           -5.1875e-01,  1.2882e+00]],

         [[-2.2759e+00,  2.5129e-01, -1.4601e-01,  ...,  4.3364e-02,
            5.8758e-01,  1.9097e+00],
          [ 3.4251e-01,  1.9955e-01, -1.2129e+00,  ..., -4.5545e-01,
            1.4100e+00, -4.2948e-01],
          [-2.0770e-01, -7.7857e-02,  1.3584e-01,  ..., -1.0823e+00,
           -1.8285e-01, -1.7177e+00],
          ...,
          [ 5.4050e-01, -8.5378e-01,  7.5097e-01,  ..., -4.4317e-01,
           -1.8126e-01,  2.0159e+00],
          [ 1.3528e+00,  1.3936e-01, -3.8658e-01,  ..., -3.7301e-01,
            4.3743e-01, -3.2276e-01],
          [-4.9224e-01, -5.7551e-01,  1.3299e+00,  ..., -1.1900e+00,
           -9.9705e-01, -1.5391e+00]]],


        [[[-4.6765e-01,  1.7469e+00,  4.1167e-01,  ..., -2.9727e-01,
            1.0069e+00, -4.3946e-01],
          [ 1.2865e-01,  3.2716e-01, -1.0923e-01,  ..., -9.0709e-01,
           -2.7334e-02,  1.9984e-02],
          [-6.6735e-02, -2.5906e-01, -8.5617e-01,  ..., -1.5415e+00,
           -1.0417e+00, -9.9276e-02],
          ...,
          [ 5.0312e-01, -9.3734e-02,  1.4636e+00,  ...,  6.6212e-02,
            2.8279e+00,  1.9585e-01],
          [-6.6980e-01, -8.6900e-01,  1.5784e+00,  ...,  6.7204e-01,
           -6.9010e-01,  2.0303e-01],
          [ 1.2830e+00, -1.5911e-01,  1.3061e+00,  ...,  2.4682e-01,
            2.4999e-01,  1.0558e-01]],

         [[-2.5318e+00,  7.9681e-01, -8.8224e-01,  ..., -7.2539e-01,
           -1.4089e+00,  3.1891e-01],
          [ 4.9546e-01, -6.2545e-02,  7.4707e-01,  ...,  4.5729e-01,
           -1.2813e-01, -5.9618e-01],
          [ 6.7128e-02,  4.6026e-01,  3.2161e-01,  ..., -7.8101e-02,
            1.4282e+00,  1.3884e+00],
          ...,
          [ 1.1093e+00, -2.1053e-01, -7.0763e-02,  ...,  6.2094e-01,
            1.0368e+00, -9.6021e-01],
          [ 4.3935e-01, -1.5691e+00, -1.6600e-01,  ...,  1.6468e+00,
           -5.5486e-01, -1.2749e-01],
          [ 2.6243e-01,  6.5141e-01,  4.5039e-01,  ...,  2.0040e-01,
            1.3858e+00, -1.1163e+00]],

         [[-4.8570e-02, -1.1930e+00, -1.1661e+00,  ..., -6.5790e-01,
           -4.5344e-01,  1.0351e+00],
          [-8.0445e-01, -3.8162e-01,  1.1300e+00,  ..., -5.1038e-01,
           -4.8811e-01, -2.3436e+00],
          [-4.6834e-01, -1.6518e-01,  7.2825e-01,  ...,  1.2359e+00,
           -8.2068e-02, -1.6591e-02],
          ...,
          [ 5.1922e-01,  1.3415e+00, -1.9047e+00,  ...,  1.7332e+00,
           -7.3272e-02,  3.4200e-01],
          [ 1.4703e+00, -1.3090e+00,  1.5598e+00,  ..., -1.9445e+00,
           -7.9541e-01, -3.4640e-01],
          [-1.6008e-01, -5.2161e-01,  7.6432e-01,  ..., -1.2405e+00,
           -9.2732e-01,  5.7209e-01]]],


        ...,


        [[[-6.8262e-01,  4.7326e-01,  4.2971e-01,  ...,  1.7691e+00,
           -3.9737e-01, -8.3816e-01],
          [ 4.9984e-01, -1.3731e-01, -7.8729e-01,  ..., -9.9463e-01,
           -5.8860e-01, -8.2637e-01],
          [ 8.5733e-01,  1.0303e+00, -2.6355e-02,  ...,  9.2787e-01,
           -1.3474e+00,  5.6791e-01],
          ...,
          [-6.8053e-01, -1.4071e+00, -1.4041e-02,  ...,  6.6805e-01,
           -2.1749e+00,  6.5959e-01],
          [-6.1508e-01, -1.2307e+00,  1.6682e+00,  ...,  8.9499e-01,
            1.8058e+00,  3.5958e-01],
          [ 1.4247e+00, -5.1951e-01,  7.4223e-01,  ...,  1.5904e+00,
           -1.9811e+00,  1.5049e+00]],

         [[-3.4248e-01,  1.1882e+00,  3.4814e-01,  ...,  3.5260e-01,
           -9.0572e-01, -3.7867e-01],
          [ 9.5852e-01, -1.7989e+00,  6.3274e-01,  ...,  1.1048e+00,
           -1.1374e+00,  1.0441e+00],
          [-6.8728e-01, -3.5182e-02, -4.7950e-01,  ..., -4.6141e-01,
           -1.4892e+00, -1.6219e+00],
          ...,
          [-6.3503e-01, -6.2600e-02, -1.0882e+00,  ...,  7.6589e-02,
           -9.0240e-01, -3.5643e-01],
          [-9.4401e-01, -9.8296e-01,  2.3919e-01,  ...,  3.1500e-01,
           -1.5620e+00, -5.0065e-01],
          [ 5.5152e-01, -6.1716e-01, -1.5657e+00,  ..., -4.7486e-01,
           -2.5323e-01, -4.9726e-01]],

         [[-8.6983e-01,  8.8978e-01,  2.5706e+00,  ..., -1.9899e+00,
            7.8805e-01, -5.0732e-01],
          [-1.2614e+00, -5.7335e-01, -5.9546e-01,  ..., -1.1494e+00,
           -6.2682e-02, -5.0015e-01],
          [ 1.9901e-02,  5.5815e-01,  1.4817e+00,  ..., -7.7306e-02,
           -6.5690e-01,  2.7882e-01],
          ...,
          [ 1.7375e-01, -8.3446e-01,  9.3299e-01,  ..., -5.9919e-01,
            1.5727e+00, -1.8054e+00],
          [ 8.9174e-01,  1.1397e+00, -1.8798e+00,  ..., -5.6368e-01,
            5.9911e-01, -6.1682e-03],
          [ 2.8880e-03, -9.8632e-01,  1.6279e+00,  ...,  9.0535e-01,
            5.7768e-01, -1.2033e+00]]],


        [[[ 1.3349e+00, -2.3353e-01, -1.1031e+00,  ..., -1.1516e+00,
            1.0387e+00, -3.9853e-01],
          [ 7.7566e-01, -4.4274e-01, -1.7194e+00,  ...,  3.2070e-01,
           -8.4431e-01, -8.4735e-01],
          [ 1.9050e+00, -7.4705e-01, -5.4816e-01,  ..., -6.9642e-01,
           -4.2740e-01, -1.7113e+00],
          ...,
          [-9.0888e-01, -1.9023e-01,  3.0928e-01,  ..., -4.0463e-01,
            1.8053e+00,  2.6060e-01],
          [ 1.5428e-01,  1.9307e+00, -7.9099e-01,  ...,  2.3353e-01,
           -1.6791e-01,  6.1215e-01],
          [ 6.6591e-01,  3.0180e-01, -5.3162e-01,  ..., -2.3198e-01,
            9.5196e-01, -1.0774e+00]],

         [[-1.2617e-01,  1.0789e+00, -1.1568e+00,  ..., -5.0699e-02,
           -9.5387e-01, -1.4903e-01],
          [ 1.2080e+00,  5.5068e-01,  6.8881e-01,  ..., -4.6396e-01,
           -1.0738e-01, -1.5264e-01],
          [-1.1720e-03, -8.3440e-01, -1.6276e-01,  ..., -2.3065e+00,
            1.2214e+00,  1.8779e+00],
          ...,
          [ 2.2949e-01,  4.0339e-01, -9.7314e-01,  ...,  6.9415e-02,
            2.6810e-01,  9.8845e-01],
          [ 1.6562e-01, -3.8893e-01,  1.8187e+00,  ...,  8.3348e-01,
           -8.9611e-01, -1.3435e-01],
          [-2.7845e-01,  6.1842e-01, -8.6288e-01,  ...,  3.3269e-01,
            1.1248e+00,  1.4452e+00]],

         [[-8.4922e-03,  9.8480e-01,  7.8474e-01,  ..., -1.7979e-01,
           -5.5188e-01,  5.2757e-01],
          [ 1.5363e+00, -7.0678e-01, -1.0823e+00,  ...,  8.7873e-01,
            4.4479e-01, -1.1418e-01],
          [-6.5378e-01, -6.2934e-01,  5.2322e-01,  ..., -1.1436e-01,
           -6.7200e-01, -1.7398e+00],
          ...,
          [-1.1200e-01,  1.4083e-01,  9.3912e-01,  ...,  1.7179e+00,
           -1.1647e+00, -5.1535e-01],
          [ 4.1458e-01, -1.2681e+00, -4.9710e-02,  ...,  5.0798e-01,
           -2.6035e+00, -1.8036e-01],
          [ 2.1277e-01, -7.1593e-01,  1.1685e+00,  ..., -1.2260e+00,
            2.3016e-01, -7.6723e-02]]],


        [[[-1.3442e-01, -7.8200e-01,  8.6340e-02,  ...,  2.1709e-02,
            1.2761e+00,  2.0928e-01],
          [ 9.0745e-02,  1.1726e+00, -5.5066e-01,  ..., -1.5055e+00,
           -2.6747e-02,  2.8691e-01],
          [-1.1133e+00,  7.2821e-01,  2.2393e+00,  ..., -5.2579e-01,
           -4.7290e-01,  1.9526e-01],
          ...,
          [-7.9382e-01, -8.9846e-01, -1.0758e+00,  ..., -1.0854e+00,
            1.1178e-01,  9.1820e-01],
          [-1.1788e+00, -6.8295e-01,  6.5025e-01,  ..., -7.1424e-01,
           -8.0331e-01,  5.0874e-01],
          [-1.4446e+00, -1.2125e+00, -1.7905e+00,  ...,  8.3627e-01,
           -1.2504e-01, -2.1789e+00]],

         [[-1.2523e+00,  3.6981e-01, -9.1756e-01,  ..., -1.0783e+00,
           -6.6086e-01,  3.4098e-01],
          [-7.4789e-01,  3.0540e-01, -1.2102e-01,  ..., -2.1479e+00,
            4.0967e-01,  2.3575e+00],
          [-4.8896e-01,  2.5413e-01,  5.3710e-01,  ...,  8.0420e-01,
            9.4549e-01,  9.3925e-01],
          ...,
          [-1.0907e+00, -1.6896e+00,  2.8979e-01,  ...,  4.4099e-01,
           -3.3184e+00, -1.3236e-01],
          [ 5.5252e-01, -1.3304e+00, -3.7177e-01,  ...,  1.1142e+00,
            5.3729e-01,  1.5201e-02],
          [ 1.4826e-01, -1.2266e+00, -3.6960e-01,  ..., -3.2862e-01,
            1.6755e+00,  4.4076e-01]],

         [[-1.0729e+00, -7.6326e-01, -1.4185e-01,  ...,  3.4313e-01,
            1.7468e+00,  1.3469e-01],
          [-5.2029e-02, -4.3576e-01,  2.2667e-01,  ..., -1.1486e+00,
           -1.7975e-03,  1.2411e+00],
          [ 4.4111e-01,  2.5796e-01, -4.2752e-02,  ...,  6.5107e-01,
            1.7516e+00, -1.6947e-01],
          ...,
          [-3.6919e-01,  2.5534e-01,  2.2259e+00,  ...,  2.4216e-01,
            6.4547e-01, -6.2078e-01],
          [-5.5642e-01,  4.4211e-01,  1.6152e-01,  ..., -1.5258e+00,
           -2.7078e+00,  1.5464e+00],
          [-5.8826e-01, -1.1390e+00, -7.2340e-01,  ..., -7.3673e-01,
           -1.1131e+00, -8.6268e-01]]]], device='cuda:0')
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-89a8c9066ce8> in <module>()
      4 x1.to('cuda')
      5 x2.to('cuda')
----> 6 output = model_ensemb(x1, x2)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in conv2d_forward(self, input, weight)
    340                             _pair(0), self.dilation, self.groups)
    341         return F.conv2d(input, weight, self.bias, self.stride,
--> 342                         self.padding, self.dilation, self.groups)
    343 
    344     def forward(self, input):

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

Based on the stacktrace, it seems you are not using this approach:

      4 x1.to('cuda')
      5 x2.to('cuda')
----> 6 output = model_ensemb(x1, x2)

and do not reassign x1 and x2.

1 Like

How can I do that? I am using this->

model_ensemb = MyEnsemble(model_vgg, model_dense)
model_ensemb.to('cuda')
x1, x2 = torch.randn(20, 3,224,224), torch.randn(20,3,224,224)
x1.to('cuda')
x2.to('cuda')
output = model_ensemb(x1, x2)

Edit->Thanks for your answer, I got it.

Last error is solved, but this time when I

print(output)
tensor([[0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893],
        [0.3079, 0.3893]], device='cuda:0', grad_fn=<AddmmBackward>)

Is this normal?