RuntimeError: Tensor for argument #2 'weight' is on CPU, but expected it to be on GPU (while checking arguments for cudnn_batch_norm)

Hi !, I just started training my model with PyTorch. However, I met a problem when I trained my model with GPU. The error message is shown below.
RuntimeError: Tensor for argument #2 ‘weight’ is on CPU, but expected it to be on GPU (while checking arguments for cudnn_batch_norm)
And this is my train.py

import torch
import torch.nn as nn
import torch.optim as optim
from make_data import train_dataloader, test_dataloader
from make_net import net, Net
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
import os


# os.environ["CUDA_VISIBLE_DEVICES"] ="1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

criterion = nn.MSELoss()
opt_Adam = optim.Adam(net.parameters(), lr=0.1, betas=(0.9, 0.99))
scheduler = ReduceLROnPlateau(opt_Adam, mode='min')

def train_model(model, criterion, optimizer, scheduler, num_epochs=30):
    since = time.time()

    train_loader =train_dataloader
    net = model()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs")
        net = nn.DataParallel(net)
    net.to(device)

    criterion = criterion
    optimizer = optimizer
    scheduler = scheduler

    for epoch in range(num_epochs):
        running_loss = 0.0
        print("Epoch {}/{}".format(epoch, num_epochs-1))
        print("-" * 10)

        for i, sample in enumerate(train_loader, 0):
            image, pressure = sample['image'], sample['pressure']

            image = image.float()
            image = image.to(device)

            pressure = pressure.float()
            pressure = pressure.to(device)

            optimizer.zero_grad()
            output = net(pressure)
            loss = criterion(output, image)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if (i+1) % 100 == 0:
                print("%d, %5d, loss: %.3f" % (epoch, i, running_loss/100))
                running_loss = 0.0
        # scheduler.step()


train_model(model=Net, criterion=criterion, optimizer=opt_Adam, scheduler=scheduler)

I don’t know how to fix this error.
I would appreciate it if you could give me some valuable advice.

1 Like

I’m not sure, where you are initializing the net instance, as the code shouldn’t run in the first place.
(opt_Adam is using net.parameters() outside of train_model, which will fail, since net is initialized inside train_model)

This minimal example will run:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train_model(model, criterion, num_epochs=30):
    
    net.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.1, betas=(0.9, 0.99))

    for epoch in range(num_epochs):
        running_loss = 0.0
        print("Epoch {}/{}".format(epoch, num_epochs-1))
        print("-" * 10)

        for i in range(1):
            image = torch.randn(1, 1)
            image = image.to(device)

            pressure = torch.randn(1, 1)
            pressure = pressure.to(device)

            optimizer.zero_grad()
            output = net(pressure)
            loss = criterion(output, image)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if (i+1) % 100 == 0:
                print("%d, %5d, loss: %.3f" % (epoch, i, running_loss/100))
                running_loss = 0.0
        



class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 1)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


net = Net()
train_model(model=net, criterion=nn.MSELoss())
1 Like
  • Many thanks for your prompt reply!
    However, after I modified the code for model training, the same error persisted.
    The error message is shown below…
Traceback (most recent call last):
  File "make_train.py", line 44, in <module>
    output = net(pressure)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/data/lyf/pytorch/make_net.py", line 92, in forward
    x = F.elu(nn.BatchNorm2d(num_features=self.f_dim*8)(x))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/batchnorm.py", line 83, in forward
    exponential_average_factor, self.eps)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1697, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: Tensor for argument #2 'weight' is on CPU, but expected it to be on GPU (while checking arguments for cudnn_batch_norm)

  • Please allow me to show you the code for the network definition.
import torch
import torch.nn as nn
import torch.nn.functional as F
from make_ops import conv_out_size_same
from make_data import batch_size


s_h, s_w = 403, 640
# 403,640
s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
# 202, 320
s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
# 101, 160
s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
# 51, 80
s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)
# 25, 40
s_h32, s_w32 = conv_out_size_same(s_h16, 2), conv_out_size_same(s_w16, 2)
# 12, 20
s_h64, s_w64 = conv_out_size_same(s_h32, 2), conv_out_size_same(s_w32, 2)
# 6, 10
s_h128, s_w128 = conv_out_size_same(s_h64, 2), conv_out_size_same(s_w64, 2)
# 3,5
s_h256, s_w256 = conv_out_size_same(s_h128, 2), conv_out_size_same(s_w128, 2)
# 2, 3


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.CONV1_DEPTH = 2
        self.CONV2_DEPTH = 4
        self.CONV3_DEPTH = 8
        self.CONV4_DEPTH = 16
        self.CONV5_DEPTH = 32
        self.CONV6_DEPTH = 64
        self.CONV7_DEPTH = 128
        self.CONV8_DEPTH = 256
        self.f_dim = 32
        self.channel_dim = 1
        self.FC_NODE = 512
        self.IMG_HEIGHT = 403
        self.IMG_WIDTH = 640
        self.batch_size = batch_size

        self.fc1 = nn.Linear(in_features=10, out_features=self.f_dim*8)
        self.fc2 = nn.Linear(in_features=self.f_dim*8, out_features=self.f_dim*8*s_w256*s_h256)

        self.deconv1 = nn.ConvTranspose2d(in_channels=self.f_dim*8, out_channels=self.f_dim*4,
                                          kernel_size=2, stride=2)
        self.deconv2 = nn.ConvTranspose2d(in_channels=self.f_dim*4, out_channels=self.f_dim*2,
                                          kernel_size=2, stride=2)
        self.deconv3 = nn.ConvTranspose2d(in_channels=self.f_dim*2, out_channels=self.f_dim,
                                          kernel_size=2, stride=2)
        self.deconv4 = nn.ConvTranspose2d(in_channels=self.f_dim, out_channels=self.f_dim//2,
                                          kernel_size=2, stride=2)
        self.deconv5 = nn.ConvTranspose2d(in_channels=self.f_dim//2, out_channels=self.f_dim//4,
                                          kernel_size=2, stride=2)
        self.deconv6 = nn.ConvTranspose2d(in_channels=self.f_dim//4, out_channels=self.f_dim//8,
                                          kernel_size=2, stride=2)
        self.deconv7 = nn.ConvTranspose2d(in_channels=self.f_dim//8, out_channels=self.f_dim//16,
                                          kernel_size=2, stride=2)
        self.deconv8 = nn.ConvTranspose2d(in_channels=self.f_dim//16, out_channels=self.channel_dim,
                                          kernel_size=2, stride=2)

        self.conv1 = nn.Conv2d(in_channels=self.channel_dim, out_channels=self.CONV1_DEPTH,
                               kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=self.CONV1_DEPTH, out_channels=self.CONV2_DEPTH,
                               kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=self.CONV2_DEPTH, out_channels=self.CONV3_DEPTH,
                               kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(in_channels=self.CONV3_DEPTH, out_channels=self.CONV4_DEPTH,
                               kernel_size=2, stride=2)
        self.conv5 = nn.Conv2d(in_channels=self.CONV4_DEPTH, out_channels=self.CONV5_DEPTH,
                               kernel_size=2, stride=2)
        self.conv6 = nn.Conv2d(in_channels=self.CONV5_DEPTH, out_channels=self.CONV6_DEPTH,
                               kernel_size=2, stride=2)
        self.conv7 = nn.Conv2d(in_channels=self.CONV6_DEPTH, out_channels=self.CONV7_DEPTH,
                               kernel_size=2, stride=2)
        self.conv8 = nn.Conv2d(in_channels=self.CONV7_DEPTH, out_channels=self.CONV8_DEPTH,
                               kernel_size=2, stride=2)

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=1)

    def forward(self, input_tensor):

        x = self.fc1(input_tensor)
        x = self.fc2(x)
        x = x.view(-1, self.f_dim*8, s_h256, s_w256)
        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim*8)(x))

        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim*4)(self.deconv1(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim*2)(self.deconv2(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim)(self.deconv3(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim//2)(self.deconv4(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim//4)(self.deconv5(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim//8)(self.deconv6(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.f_dim//16)(self.deconv7(x)))
        x = F.tanh(nn.BatchNorm2d(num_features=self.channel_dim)(self.deconv8(x)))

        x = F.elu(nn.BatchNorm2d(num_features=self.CONV1_DEPTH)(self.conv1(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.CONV2_DEPTH)(self.conv2(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.CONV3_DEPTH)(self.conv3(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.CONV4_DEPTH)(self.conv4(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.CONV5_DEPTH)(self.conv5(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.CONV6_DEPTH)(self.conv6(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.CONV7_DEPTH)(self.conv7(x)))
        x = F.elu(nn.BatchNorm2d(num_features=self.CONV8_DEPTH)(self.conv8(x)))

        x = self.avg_pool(x)
        x = x.view(-1, self.num_flat_features(x))
        x = F.elu(nn.Linear(in_features=x.size()[-1], out_features=self.FC_NODE)(x))
        x = nn.Linear(in_features=self.FC_NODE, out_features=self.IMG_HEIGHT*self.IMG_WIDTH)(x)
        x = x.view(-1, self.IMG_HEIGHT, self.IMG_WIDTH)

        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

The error message tells me Error in checking BatchNorm used in Cudnn.

  • The code for my modified model training is shown below.
import torch
import torch.nn as nn
import torch.optim as optim
from make_data import train_dataloader, test_dataloader
from make_net import Net


num_epochs = 30
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_loader = train_dataloader

criterion = nn.MSELoss()
net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.1, betas=(0.9, 0.99))

if torch.cuda.is_available():
    print("Let's use", torch.cuda.device_count(), "GPUs")
    net = nn.DataParallel(net)
net.to(device)


for epoch in range(num_epochs):
    running_loss = 0.0
    print("Epoch {}/{}".format(epoch, num_epochs-1))
    print("-" * 10)

    for i, sample in enumerate(train_loader, 0):
        image, pressure = sample['image'], sample['pressure']

        image = image.float()
        image = image.to(device)

        pressure = pressure.float()
        pressure = pressure.to(device)

        optimizer.zero_grad()
        output = net(pressure)
        loss = criterion(output, image)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (i+1) % 100 == 0:
            print("%d, %5d, loss: %.3f" % (epoch, i, running_loss/100))
            running_loss = 0.0


  • Many thanks for your prompt reply!
    Thank you very much!!!

Thanks for the code.
You are creating some layers inside the forward method on the fly.

These layers will be reinitialized in each forward pass, so they won’t be trained, and also won’t get pushed to the GPU, since they are unknown to the model when you call model.to(device).

The standard way would be to create these layers in the __init__ method (as your other layers) and just apply them on the activation in forward.

If you really want to use randomly initialized layers in the forward pass, you could use your code and push these newly created layers manually using the device attribute of another already registered parameter.

4 Likes
  • Thank you very much for your prompt reply and valuable comments.
  • I will modify my code according to your opinion.
  • Thank you again