Moving from PyTorch 0.4.1 to 1.0 changes model output

Hello,

I have a pretrained VGG-face descriptor that I downloaded from the official website and I am just doing inference on it. I export the weights from torch(lua) to PyTorch with the following code:

class VGG_16(nn.Module):
    """
    Main Class
    """

    def __init__(self):
        """
        Constructor
        """
        super().__init__()
        self.block_size = [2, 2, 3, 3, 3]
        self.conv_1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.conv_1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv_2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.conv_3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.conv_4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.fc6 = nn.Linear(512 * 7 * 7, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 2622)

    def load_weights(self, path="/home/diego/Projects/dfw_benchmark/pretrained/VGG_FACE.t7"):
        """ Function to load luatorch pretrained

        Args:
            path: path for the luatorch pretrained
        """
        model = torchfile.load(path)
        counter = 1
        block = 1
        for i, layer in enumerate(model.modules):
            if layer.weight is not None:
                if block <= 5:
                    self_layer = getattr(self, "conv_%d_%d" % (block, counter))
                    counter += 1
                    if counter > self.block_size[block - 1]:
                        counter = 1
                        block += 1
                    self_layer.weight.data[...] = torch.tensor(layer.weight).view_as(self_layer.weight)[...]
                    self_layer.bias.data[...] = torch.tensor(layer.bias).view_as(self_layer.bias)[...]
                else:
                    self_layer = getattr(self, "fc%d" % (block))
                    block += 1
                    self_layer.weight.data[...] = torch.tensor(layer.weight).view_as(self_layer.weight)[...]
                    self_layer.bias.data[...] = torch.tensor(layer.bias).view_as(self_layer.bias)[...]

    def forward(self, x):
        """ Pytorch forward

        Args:
            x: input image (224x224)

        Returns: class logits

        """
        x = F.relu(self.conv_1_1(x))
        x = F.relu(self.conv_1_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_2_1(x))
        x = F.relu(self.conv_2_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_3_1(x))
        x = F.relu(self.conv_3_2(x))
        x = F.relu(self.conv_3_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_4_1(x))
        x = F.relu(self.conv_4_2(x))
        x = F.relu(self.conv_4_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_5_1(x))
        x = F.relu(self.conv_5_2(x))
        x = F.relu(self.conv_5_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc6(x))
        x = F.dropout(x, 0.5)
        x = F.relu(self.fc7(x))
        x = F.dropout(x, 0.5)
        return self.fc8(x)


if __name__ == "__main__":
    model = VGG_16()
    model.load_weights()
    im = cv2.imread("../ak.png")
    im = torch.Tensor(im).permute(2, 0, 1).view(1, 3, 224, 224)
    import numpy as np
    model.eval()
    im -= torch.Tensor(np.array([129.1863, 104.7624, 93.5940])).view(1, 3, 1, 1)

    preds = F.softmax(model(im), -1)
    print(preds)
    values, indices = preds.max(-1)
    print(indices)

The output is correct when using PyTorch 0.4.1:

[[1.7682e-05, 2.1315e-04, 9.6148e-01,  ..., 7.4219e-08, 1.3229e-06,
         4.7004e-07]],

But upgrading to PyTorch 1.0 or 1.0post2 yields:

[[7.2586e-05, 5.0534e-05, 9.7303e-01, ..., 5.4698e-09, 2.3558e-07, 1.4000e-07]

Why is this happening? Thanks in advance!!

1 Like

Is the difference larger than what can be expected from numerical accuracy? I’d venture that some optimisation (maybe softmax?) Changed the calculation order. To test this, convert model and inputs to double and see if the difference goes down significantly.

Best regards

Thomas

Hello Thomas, thanks for your help!

I tried what you proposed and converted the model and the input to doubles, I also removed the softmax operation to help narrow down the problem. So now the model output on PyTorch 1.0.1.post2 is:

tensor([[ 3.0757,  5.4147, 15.5268,  ..., -3.8933,  1.6677, -0.2514]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

And on PyTorch 0.4.1:

tensor([[ 2.7830,  5.2725, 13.6867,  ..., -2.6902,  0.1903, -0.8445]],
       dtype=torch.float64, grad_fn=<ThAddmmBackward>)

Also I noticed that the model’s output when running on PyTorch 1.0.1.post2 seems to be non-deterministic:

Run 1->

 tensor([[ 3.0757,  5.4147, 15.5268,  ..., -3.8933,  1.6677, -0.2514]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

Run 2->

tensor([[ 3.2392,  7.3212, 15.5781,  ..., -3.5685,  0.5535,  0.7085]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

Run 3->

tensor([[ 5.2961,  6.7340, 22.7242,  ..., -4.9779,  0.6964, -1.0704]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

As you can see the differences are not at all small. What on earth is happening?
Could it be that F.dropout() is not turned off by model.eval() on Pytorch 1.0??

Thanks in advance!

If anyone else is having a simillar issue I solved it by changing the dropout layers from: F.dropout(x, 0.5) to F.dropout(x, 0.5, self.training) in the above code. The problem is now solved but I still don’t know why the bahviour of F.dropout() changes from 0.4.1 to 1.0.

1 Like

https://pytorch.org/docs/0.4.1/nn.html#id46 shows training=False as the default, while https://pytorch.org/docs/stable/nn.html#id47 defaults to training=True.

I don’t know why the default value changed.

1 Like