Output Discrepancy Between CPU and GPU Execution of Custom PyTorch Model

I’ve encountered an issue with my custom PyTorch model that uses torch.nn.Conv2d. When I move the model from CPU to GPU, its output changes even though the input is fixed. Here’s a simplified version of my code:

import torch
import torch.nn as nn


class MyModel(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        depth,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        device=None,
    ):
        super(MyModel, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.device = device

        # Convolutional layers
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels * depth,
            1,
            stride=1,
            padding=0,
            bias=False,
            device=device,
        )
        self.depthwise = nn.Conv2d(
            out_channels * depth,
            out_channels * depth,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=out_channels * depth,
            bias=False,
            device=device,
        )

        # Bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, device=device))
        else:
            self.bias = None

    def forward(self, input):

        # Convolution
        output = self.pointwise(input)
        output = self.depthwise(output)

        # Reshape and sum
        output = output.view(
            output.shape[0], self.out_channels, self.depth, *output.shape[2:]
        )
        output = torch.sum(output, dim=2)

        # Add bias if applicable
        if self.bias is not None:
            output += self.bias.view(1, -1, 1, 1)

        return output


if __name__ == "__main__":
    model_cpu = MyModel(in_channels=12, out_channels=2, depth=9, kernel_size=3)
    print(f"model_cpu {model_cpu}")
    input_cpu = torch.randn((2, 12, 4, 4))
    output_cpu = model_cpu(input_cpu)
    print(f"output_cpu {output_cpu}")

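    # nn.Module.to() moves the parameters in place and returns the same module,
    # so from here on model_cpu and model_gpu both refer to the GPU model.
    model_gpu = model_cpu.to('cuda')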
    print(f"model_gpu {model_gpu}")
    input_gpu = input_cpu.to('cuda')
    output_gpu = model_gpu(input_gpu)
    print(f"output_gpu {output_gpu}")

Here’s the output:

(base) van-tien.pham@machine:~/SVD$ python test.py 
model_cpu MyModel(
  (pointwise): Conv2d(12, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (depthwise): Conv2d(18, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=18, bias=False)
)
output_cpu tensor([[[[ 0.0352,  0.8974,  0.4504, -0.5259],
          [ 0.3416,  0.9579,  0.2390, -0.9380],
          [ 1.6233, -1.4409,  0.8132,  0.5683],
          [-0.2270, -0.1542, -1.4981,  0.3713]],

         [[ 0.2146,  0.5099,  0.1943,  0.2672],
          [ 0.8518,  1.0509, -0.7429,  1.4174],
          [-0.6285, -1.2005, -0.7633, -1.0010],
          [-0.1008,  0.3585, -1.1639, -0.1421]]],


        [[[ 0.6001,  1.3276,  0.3367, -0.6931],
          [ 1.5325,  0.8300,  0.1708, -1.2455],
          [-1.2145,  0.5081,  1.0613,  0.9404],
          [ 0.4986,  0.5428,  0.9134,  1.3834]],

         [[ 0.0309, -0.4950,  0.6357,  0.2674],
          [ 0.6251,  0.1574,  0.3819, -0.3022],
          [ 1.4431,  0.4977, -1.2133,  0.1650],
          [ 0.0207,  0.0846,  0.5615,  0.4055]]]], grad_fn=<AddBackward0>)
model_gpu MyModel(
  (pointwise): Conv2d(12, 18, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (depthwise): Conv2d(18, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=18, bias=False)
)
output_gpu tensor([[[[ 0.0351,  0.8976,  0.4507, -0.5258],
          [ 0.3413,  0.9577,  0.2393, -0.9379],
          [ 1.6232, -1.4403,  0.8125,  0.5687],
          [-0.2269, -0.1542, -1.4982,  0.3713]],

         [[ 0.2145,  0.5097,  0.1945,  0.2672],
          [ 0.8518,  1.0509, -0.7430,  1.4173],
          [-0.6284, -1.2005, -0.7634, -1.0011],
          [-0.1005,  0.3584, -1.1636, -0.1423]]],


        [[[ 0.5998,  1.3273,  0.3366, -0.6932],
          [ 1.5322,  0.8298,  0.1706, -1.2456],
          [-1.2146,  0.5079,  1.0608,  0.9404],
          [ 0.4988,  0.5421,  0.9135,  1.3837]],

         [[ 0.0308, -0.4950,  0.6357,  0.2676],
          [ 0.6255,  0.1573,  0.3818, -0.3021],
          [ 1.4432,  0.4981, -1.2132,  0.1653],
          [ 0.0203,  0.0845,  0.5616,  0.4055]]]], device='cuda:0',
       grad_fn=<AddBackward0>)

The outputs output_cpu and output_gpu differ slightly (by up to about 7e-4 in the printed values), although the input is the same. Strangely, if I use a batch size of 1, the outputs are identical. My custom model consists of two sequential nn.Conv2d layers, a view operation, a torch.sum, and an in-place addition in the forward method.
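A minimal sketch for quantifying a mismatch like this instead of eyeballing printed values: it compares a float32 result against a float64 reference computed on the CPU and then checks it with a tolerance. The layer shapes mirror the question's pointwise/depthwise pair, but the standalone `nn.Sequential` setup and the tolerance values are assumptions for illustration. It uses the GPU if one is present and otherwise falls back to the CPU.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for MyModel's convolution path: a 1x1 pointwise conv
# followed by a 3x3 depthwise conv, with the same shapes as in the question.
pointwise = nn.Conv2d(12, 18, 1, bias=False)
depthwise = nn.Conv2d(18, 18, 3, padding=1, groups=18, bias=False)
net = nn.Sequential(pointwise, depthwise)

x = torch.randn(2, 12, 4, 4)

# float64 reference computed on the CPU (.double() converts net in place)
with torch.no_grad():
    ref = net.double()(x.double())

# float32 result on the GPU if available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.no_grad():
    out = net.float().to(device)(x.to(device)).cpu().double()

max_abs_err = (out - ref).abs().max().item()
print(f"device={device} max abs error vs float64 reference: {max_abs_err:.2e}")

# Exact equality is too strict for floating point; compare with a tolerance.
# These tolerances are illustrative assumptions, not recommended values.
close = torch.allclose(out, ref, rtol=1e-3, atol=1e-2)
print("allclose:", close)
```

Comparing each device's float32 result against a higher-precision reference shows how far each one is from the exact answer, rather than only how far they are from each other.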

I’ve tried both .cuda() and .to('cuda') for moving the model and input to the GPU, but the results remain inconsistent. Could you please help me identify the root cause of this issue? Is there a mistake in my model implementation, or is it related to data movement between devices?

I have examined this related question, but unfortunately, the information provided there does not seem to address the specific issues I’m encountering in my case.

Thank you for your assistance!

I cannot reproduce this problem with my install, but when it comes to reproducibility you should take a look at the Reproducibility page of the PyTorch 2.2 documentation.
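While I recommend reading that docs page, setting torch.backends.cudnn.deterministic = True may be the solution to your problem: it makes cuDNN use only deterministic convolution algorithms. Also, torch.backends.cudnn.benchmark = False disables cuDNN's algorithm autotuning, so the same algorithm is picked on every run, which should keep results consistent between runs.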
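The two settings above can be applied as in this minimal sketch. The small `Conv2d` layer and input are hypothetical stand-ins, and the check only verifies that two repeated runs on the same device match bit for bit. Note that these flags target run-to-run determinism on a single device. The same documentation page states that results may not be reproducible between CPU and GPU executions, even with identical seeds.

```python
import torch

# Ask cuDNN to use only deterministic convolution algorithms.
torch.backends.cudnn.deterministic = True
# Disable cuDNN autotuning so the same algorithm is chosen on every run.
torch.backends.cudnn.benchmark = False
# Seed the RNG so the weights and inputs are identical across runs.
torch.manual_seed(0)

x = torch.randn(2, 12, 4, 4)
conv = torch.nn.Conv2d(12, 18, 1, bias=False)

device = "cuda" if torch.cuda.is_available() else "cpu"
conv = conv.to(device)
x = x.to(device)

with torch.no_grad():
    first = conv(x)
    second = conv(x)

# Repeated runs on the same device should match bit for bit.
print(torch.equal(first, second))
```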


Small numerical mismatches are expected, as the order of operations might differ between platforms.
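Floating-point addition is not associative, so summing the same values in a different order can change the result. A minimal illustration with Python floats (float64); float32 tensors behave the same way, only with larger rounding errors:

```python
# The same three numbers summed in two different orders.
a, b, c = 0.1, 0.2, 0.3

left_to_right = (a + b) + c
right_to_left = a + (b + c)

print(left_to_right)   # 0.6000000000000001
print(right_to_left)   # 0.6
print(left_to_right == right_to_left)  # False
```

A CPU kernel and a GPU kernel can split the reduction inside a convolution differently (for example across threads or tiles), which has the same effect on a larger scale.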
Depending on the magnitude of the numerical mismatches and the GPU you are using, TF32 might be enabled on your GPU for performance reasons. More details can be found in the docs.
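TF32 keeps float32's 8-bit exponent but only 10 mantissa bits, so each input is rounded with a relative error of up to 2^-11 (about 5e-4) before it is multiplied. The pure-Python sketch below approximates that rounding by clearing the low 13 mantissa bits of a float32 value. It then compares a dot product computed from float32-rounded inputs with one computed from TF32-rounded inputs. This is a simplified model: the helper names are hypothetical, rounding is half away from zero, accumulation happens in float64, and the hardware's real accumulation order is not modeled. To rule TF32 out in PyTorch, set `torch.backends.cudnn.allow_tf32 = False` for cuDNN convolutions (and `torch.backends.cuda.matmul.allow_tf32 = False` for matrix multiplications). TF32 is only used on Ampere or newer GPUs.

```python
import random
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Approximate TF32 rounding: keep 10 of float32's 23 mantissa bits.

    Simplified: rounds half away from zero on the bit pattern and does not
    handle NaN/Inf specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x1000) & 0xFFFFE000  # round, then clear the low 13 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


random.seed(0)
n = 100  # hypothetical dot-product length
weights = [random.gauss(0, 0.3) for _ in range(n)]
inputs = [random.gauss(0, 1) for _ in range(n)]

# Reference dot product in float64
exact = sum(w * x for w, x in zip(weights, inputs))
# Inputs rounded to float32 (accumulated in float64 here for simplicity)
fp32 = sum(to_float32(w) * to_float32(x) for w, x in zip(weights, inputs))
# Inputs rounded to TF32 before multiplying
tf32 = sum(to_tf32(w) * to_tf32(x) for w, x in zip(weights, inputs))

print(f"float32 inputs error: {abs(fp32 - exact):.1e}")
print(f"TF32 inputs error:    {abs(tf32 - exact):.1e}")
```

The TF32 error typically lands several orders of magnitude above the float32 error, which is the kind of gap that shows up as differences in the third or fourth decimal place.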
