Problem when moving model to CUDA

I’m implementing ResNet50 in PyTorch. It ran perfectly fine on the CPU, but when I moved the model to the GPU, I encountered this error:

RuntimeError                              Traceback (most recent call last)

<ipython-input-37-160652fffff7> in <module>()
      1 input = torch.rand(32,3,224,224).to(device)
----> 2 output = model(input)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    348                             _pair(0), self.dilation, self.groups)
    349         return F.conv2d(input, weight, self.bias, self.stride,
--> 350                         self.padding, self.dilation, self.groups)
    351 
    352     def forward(self, input):

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Here is the relevant code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BottleNeck(nn.Module):
    def __init__(self,base_kernel,input_kernel_depth,stride=1):
        super().__init__()
        self.base_kernel = base_kernel
        self.input_kernel_depth = input_kernel_depth
        self.stride = stride 
        
        self.conv1 = nn.Conv2d(in_channels=input_kernel_depth,out_channels=base_kernel,kernel_size=1,stride=stride) 
        self.bn1 = nn.BatchNorm2d(base_kernel)
        self.conv2 = nn.Conv2d(in_channels=base_kernel,out_channels=base_kernel,kernel_size=3,padding=1)
        self.bn2 = nn.BatchNorm2d(base_kernel)
        self.conv3 = nn.Conv2d(in_channels=base_kernel,out_channels=base_kernel*4,kernel_size=1)
        self.bn3 = nn.BatchNorm2d(base_kernel*4)
        
    def forward(self,input):
        identity = input
        
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        
        upsampler = self.upsampling(self.input_kernel_depth,self.base_kernel*4,self.stride)
        identity = upsampler(identity)
        return F.relu(output + identity)
    
    def upsampling(self,in_channels,out_channels,stride=1):
        upsampler = nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=stride)
        return upsampler

class Block(nn.Module):
    def __init__(self,input_kernel_depth,stack_count,base_kernel,stride,padding):
        super().__init__()
        layers = []
        layers.append(
            BottleNeck(base_kernel,input_kernel_depth,stride=stride)
        )
        for i in range(1,stack_count):
            layers.append(
                BottleNeck(base_kernel,base_kernel*4,stride=1)
            )
        self.block = nn.Sequential(*layers)
    
    def forward(self,input):
        output = self.block(input)
        return output
    
class ResNet50(nn.Module):
    def __init__(self,n_classes):
        super().__init__()
        self.n_classes = n_classes
        
        self.stage1 = nn.Conv2d(3,64,7,2,3)
        self.stage2 = nn.Sequential(
            nn.MaxPool2d(3,2,1).to(device),
            Block(input_kernel_depth=64,stack_count=3,base_kernel=64,stride=1,padding=0)
        )
        self.stage3 = nn.Sequential(
            Block(input_kernel_depth=256,stack_count=4,base_kernel=128,stride=2,padding=1)
        )
        self.stage4 = nn.Sequential(
            Block(input_kernel_depth=512,stack_count=6,base_kernel=256,stride=2,padding=1)
        )
        self.stage5 = nn.Sequential(
            Block(input_kernel_depth=1024,stack_count=3,base_kernel=512,stride=2,padding=1)
        )
        
        self.avg_pool = nn.AvgPool2d(kernel_size=7)
        
        self.fc1 = nn.Linear(2048,1000)
        self.fc2 = nn.Linear(1000,n_classes)

    def forward(self,input):
        batch_size = input.size(0)
        
#         print("Global input",input.shape)
        stage1_output = self.stage1(input)
#         print("Stage 1",stage1_output.shape)
        stage2_output = self.stage2(stage1_output)
#         print("Stage 2",stage2_output.shape)
        stage3_output = self.stage3(stage2_output)
#         print("Stage 3",stage3_output.shape)
        stage4_output = self.stage4(stage3_output)
#         print("Stage 4",stage4_output.shape)
        stage5_output = self.stage5(stage4_output)
#         print("Stage 5",stage5_output.shape)
        
        output = self.avg_pool(stage5_output).view(batch_size,-1)
        output = F.relu(self.fc1(output))
        output = F.log_softmax(self.fc2(output),dim=1)
        return output

model = ResNet50(4)
model.to(device)

optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
criterion = nn.NLLLoss()

The error above occurred when I tried to run some test input:

input = torch.rand(32,3,224,224).to(device)
output = model(input)

Based on the error message it seems that some parameters of the model weren’t pushed to the GPU.
In particular, creating a new module in the forward will not push it to the GPU unless you specify it explicitly:

upsampler = self.upsampling(self.input_kernel_depth,self.base_kernel*4,self.stride)

You could use upsampler = self.upsampling(self.input_kernel_depth,self.base_kernel*4,self.stride).to(identity.device) to work around this issue.
Note, however, that you are re-initializing this module with random weights in every forward pass. I’m not familiar with your use case, but this way the module will not be trained.
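
For clarity, here is roughly what that workaround looks like applied to the posted forward (a sketch only, since the module is still re-created on every call):

def forward(self,input):
    identity = input

    output = self.conv1(input)
    output = self.conv2(output)
    output = self.conv3(output)

    # the new Conv2d is created on the CPU by default, so it has to be
    # moved to the device of the incoming activation before it is called
    upsampler = self.upsampling(self.input_kernel_depth,self.base_kernel*4,self.stride).to(identity.device)
    identity = upsampler(identity)
    return F.relu(output + identity)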


Thank you for the reply. That indeed solves my case.

However, I do want this module to be trained. What is the proper way to do that?

You would have to initialize this module in the __init__ method of your module (like the other layers) and just call it in the forward, for example:
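
A minimal sketch of that change, keeping the rest of the posted code unchanged:

class BottleNeck(nn.Module):
    def __init__(self,base_kernel,input_kernel_depth,stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=input_kernel_depth,out_channels=base_kernel,kernel_size=1,stride=stride)
        self.bn1 = nn.BatchNorm2d(base_kernel)
        self.conv2 = nn.Conv2d(in_channels=base_kernel,out_channels=base_kernel,kernel_size=3,padding=1)
        self.bn2 = nn.BatchNorm2d(base_kernel)
        self.conv3 = nn.Conv2d(in_channels=base_kernel,out_channels=base_kernel*4,kernel_size=1)
        self.bn3 = nn.BatchNorm2d(base_kernel*4)
        # registered once here, so model.to(device) moves it and
        # model.parameters() exposes its weights to the optimizer
        self.upsampler = nn.Conv2d(input_kernel_depth,base_kernel*4,kernel_size=1,stride=stride)

    def forward(self,input):
        identity = self.upsampler(input)
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        return F.relu(output + identity)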

@ptrblck Could you inspect some similar code here? https://github.com/NathanUA/U-2-Net/blob/b77cd6da3204efcb03e18e15dd3b9eb24d47f969/model/u2net.py#L24

_upsample_like is called only in the forward pass, just like the OP’s self.upsampling.

I wonder if this is the cause of my libtorch module.forward() not working, even though my Python JIT trace saves and loads in libtorch without errors.

If you think it is the cause, what would be a concise fix, considering _upsample_like is called throughout all the forward implementations?

My other thread is here: Debugging runtime error module->forward(inputs) libtorch 1.4

Your code looks good, since _upsample_like only uses the functional API via F.upsample without any parameters.
You could of course create an nn.Upsample module instead, but this wouldn’t make a difference, and I think your current code is clear and shouldn’t be changed. I’ll take a look at the linked issue.
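
For reference, a parameter-free helper of that kind looks roughly like this (F.upsample is deprecated in recent releases in favor of F.interpolate, which accepts the same arguments here):

def _upsample_like(src,tar):
    # purely functional resize: there are no learnable weights, so nothing
    # needs to be registered in __init__ or moved to a device
    return F.interpolate(src,size=tar.shape[2:],mode='bilinear')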
