Big accuracy differences between PyTorch 1.13.1 and 2.0.1!

Hello,

We reported this performance issue between the two PyTorch versions as GitHub issue #105837.

As a new user I cannot seem to add more links or media, so all the info can be found in the GitHub issue.

Could someone please take a look?

Thank you in advance!

Did you check if any parameter updates are performed or if the model is static?

Thanks @ptrblck for your answer! I checked with the following code whether the parameters get updated, and they do not seem to be, which matches the originally submitted plots showing that no learning is happening.

optim.zero_grad()

params_before = list(model.parameters())

loss.backward()
optim.step()

params_after = list(model.parameters())

all_params_equal = True
for p_before, p_after in zip(params_before, params_after):
    all_params_equal = all_params_equal and torch.equal(p_before.data, p_after.data)

print(f"Are all parameters equal? {all_params_equal}")

I’m not sure what you mean by checking whether the model is static.

Thanks again!

Thanks for the check! Could you check the .grad attributes of all parameters before and after the first backward() call next?
They should be set to None before and should show a valid tensor afterwards.
If both print statements are showing None gradients, the computation graph seems to be detached and we would need to check why that’s the case.

Thanks again @ptrblck. I modified the training loop as below:

optim.zero_grad()

params_before = [param.data.cpu().detach().numpy() for param in model.parameters()]

for p_before in model.parameters():
    print(f"Grad: before - {p_before.grad}")

loss.backward()

for p_after in model.parameters():
    print(f"Grad: after - {p_after.grad}")

optim.step()

params_after = [param.data.cpu().detach().numpy() for param in model.parameters()]

all_params_equal = True
for p_before, p_after in zip(params_before, params_after):
    all_params_equal = all_params_equal and np.array_equal(p_before, p_after)

print(f"Are all parameters equal? {all_params_equal}")

Previously, I think, the stored parameters shared the same references as the live model parameters, so they were always marked as equal. After detaching them and storing them as NumPy arrays for the comparison, there are small changes in the parameters, i.e. they do seem to get updated. Sorry for this!
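For reference, the same check can also be done directly in PyTorch by cloning the parameter tensors instead of keeping references (just a sketch, assuming model, loss and optim from the snippets above):

optim.zero_grad()

# Clone the tensors so the "before" snapshot is independent of the live parameters.
params_before = [p.detach().clone() for p in model.parameters()]

loss.backward()
optim.step()

all_params_equal = all(
    torch.equal(before, after.detach())
    for before, after in zip(params_before, model.parameters())
)
print(f"Are all parameters equal? {all_params_equal}")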

The gradients are None at first, and after the first backward pass they get populated with tensors. Here is the output:

Grad: before - None
Grad: before - None
Grad: before - None
Grad: before - None
Grad: before - None
Grad: before - None
Grad: before - None
Grad: before - None
Grad: after - tensor([[ 1.4627e-03,  1.4627e-03,  1.4627e-03,  ...,  1.4627e-03,
          2.0679e-03,  1.3948e-03],
        [-1.3086e-04, -1.3086e-04, -1.3086e-04,  ..., -1.3086e-04,
         -3.8431e-04, -5.9046e-06],
        [-6.0961e-03, -6.0961e-03, -6.0961e-03,  ..., -6.0961e-03,
         -9.0613e-04, -4.6752e-03],
        ...,
        [-1.9681e-02, -1.9681e-02, -1.9681e-02,  ..., -1.9681e-02,
          1.2326e-02,  1.6755e-03],
        [ 1.2912e-03,  1.2912e-03,  1.2912e-03,  ...,  1.2912e-03,
         -1.2034e-03,  2.9203e-03],
        [-1.9582e-02, -1.9582e-02, -1.9582e-02,  ..., -1.9582e-02,
          4.8703e-04, -2.1845e-03]], device='mps:0')
Grad: after - tensor([[-1.8982e-03, -6.2929e-04, -1.5961e-03,  ..., -8.6900e-04,
         -1.0032e-03,  3.9843e-04],
        [ 1.8086e-04,  2.5083e-04, -1.6416e-04,  ...,  9.0401e-05,
         -1.1745e-04,  1.5602e-04],
        [-1.5279e-03, -5.0708e-03,  2.8829e-03,  ..., -2.1105e-03,
          3.0370e-03, -4.7238e-03],
        ...,
        [-2.6152e-02, -3.1857e-02,  1.3831e-02,  ..., -1.3835e-02,
          1.1094e-02, -1.8385e-02],
        [ 1.2020e-03,  2.4224e-03,  6.7051e-04,  ...,  1.7220e-03,
         -5.1741e-04,  2.1676e-03],
        [-1.4955e-02, -1.9936e-02,  8.3015e-03,  ..., -8.8695e-03,
          7.3607e-03, -1.2555e-02]], device='mps:0')
Grad: after - tensor([ 1.4685e-03, -1.3138e-04, -6.1201e-03, -4.5789e-02, -7.9349e-03,
        -1.6080e-03,  1.7505e-02,  2.7263e-02,  5.0776e-04,  3.1822e-03,
         2.4568e-02, -1.1433e-02,  3.9419e-03,  1.5574e-02,  6.2263e-04,
        -1.2129e-02, -5.4287e-02, -1.0878e-04, -1.1254e-02,  1.0963e-02,
         3.1155e-02, -5.8268e-03,  2.4303e-02,  4.3852e-03, -1.9106e-02,
        -1.5162e-02, -6.5663e-02,  8.1562e-02,  2.1123e-02, -3.0827e-03,
         1.7691e-02,  1.1239e-03, -3.1183e-03, -1.0129e-02,  1.1770e-03,
        -9.5528e-04, -7.8951e-04,  5.4114e-02,  5.2791e-05,  2.3061e-05,
         1.3466e-02,  1.1075e-03,  6.7796e-04,  1.1105e-01, -1.3704e-02,
         1.6770e-02, -2.0077e-03,  3.3308e-05,  6.1390e-02, -1.2727e-02,
        -5.9857e-02,  7.5183e-03, -2.3664e-02, -1.4469e-02,  3.3819e-02,
         3.1799e-02,  4.5254e-03, -2.6284e-03,  6.6589e-04, -1.3513e-03,
        -7.4101e-03, -2.9122e-03,  2.5119e-03, -5.5881e-03,  3.6866e-03,
         4.8116e-02,  4.3276e-03, -1.4491e-02, -6.1160e-02,  1.9650e-03,
        -2.2383e-02, -1.4282e-03,  6.4808e-04,  2.0432e-02, -6.1903e-02,
         4.2811e-02,  2.0047e-02, -4.6595e-04,  4.8694e-04,  4.7563e-02,
         6.9717e-02, -3.4239e-02, -3.8569e-02, -2.3586e-04, -1.8914e-02,
        -3.5095e-03, -2.3049e-03, -3.1967e-03,  1.2116e-02, -3.9122e-03,
        -2.7179e-03, -2.0984e-02,  1.8846e-02, -1.6740e-02, -9.4212e-03,
        -5.8745e-04, -2.2158e-02,  1.4861e-02, -3.1859e-02, -3.9706e-03,
        -4.2073e-02, -3.3877e-02, -3.7556e-02,  5.8694e-04,  2.3499e-02,
         6.8007e-02, -2.4977e-03, -2.4801e-02, -2.3341e-02, -7.4005e-03,
         7.1274e-04, -5.4910e-03,  1.8000e-03, -2.8609e-03, -7.8491e-03,
        -1.7710e-02,  5.1535e-03, -6.2510e-02, -1.0905e-03, -2.0986e-02,
        -2.9979e-02, -1.9756e-03,  1.6864e-02,  5.0562e-03,  1.0459e-03,
        -1.9758e-02,  1.2963e-03, -1.9659e-02], device='mps:0')
Grad: after - tensor([ 1.4685e-03, -1.3138e-04, -6.1201e-03, -4.5789e-02, -7.9349e-03,
        -1.6080e-03,  1.7505e-02,  2.7263e-02,  5.0776e-04,  3.1822e-03,
         2.4568e-02, -1.1433e-02,  3.9419e-03,  1.5574e-02,  6.2263e-04,
        -1.2129e-02, -5.4287e-02, -1.0878e-04, -1.1254e-02,  1.0963e-02,
         3.1155e-02, -5.8268e-03,  2.4303e-02,  4.3852e-03, -1.9106e-02,
        -1.5162e-02, -6.5663e-02,  8.1562e-02,  2.1123e-02, -3.0827e-03,
         1.7691e-02,  1.1239e-03, -3.1183e-03, -1.0129e-02,  1.1770e-03,
        -9.5528e-04, -7.8951e-04,  5.4114e-02,  5.2791e-05,  2.3061e-05,
         1.3466e-02,  1.1075e-03,  6.7796e-04,  1.1105e-01, -1.3704e-02,
         1.6770e-02, -2.0077e-03,  3.3308e-05,  6.1390e-02, -1.2727e-02,
        -5.9857e-02,  7.5183e-03, -2.3664e-02, -1.4469e-02,  3.3819e-02,
         3.1799e-02,  4.5254e-03, -2.6284e-03,  6.6589e-04, -1.3513e-03,
        -7.4101e-03, -2.9122e-03,  2.5119e-03, -5.5881e-03,  3.6866e-03,
         4.8116e-02,  4.3276e-03, -1.4491e-02, -6.1160e-02,  1.9650e-03,
        -2.2383e-02, -1.4282e-03,  6.4808e-04,  2.0432e-02, -6.1903e-02,
         4.2811e-02,  2.0047e-02, -4.6595e-04,  4.8694e-04,  4.7563e-02,
         6.9717e-02, -3.4239e-02, -3.8569e-02, -2.3586e-04, -1.8914e-02,
        -3.5095e-03, -2.3049e-03, -3.1967e-03,  1.2116e-02, -3.9122e-03,
        -2.7179e-03, -2.0984e-02,  1.8846e-02, -1.6740e-02, -9.4212e-03,
        -5.8745e-04, -2.2158e-02,  1.4861e-02, -3.1859e-02, -3.9706e-03,
        -4.2073e-02, -3.3877e-02, -3.7556e-02,  5.8694e-04,  2.3499e-02,
         6.8007e-02, -2.4977e-03, -2.4801e-02, -2.3341e-02, -7.4005e-03,
         7.1274e-04, -5.4910e-03,  1.8000e-03, -2.8609e-03, -7.8491e-03,
        -1.7710e-02,  5.1535e-03, -6.2510e-02, -1.0905e-03, -2.0986e-02,
        -2.9979e-02, -1.9756e-03,  1.6864e-02,  5.0562e-03,  1.0459e-03,
        -1.9758e-02,  1.2963e-03, -1.9659e-02], device='mps:0')
Grad: after - tensor([[ 2.4017e-01,  3.4895e-01, -2.4161e-01,  1.6617e-01,  1.4000e-02,
         -2.8061e-01, -2.4008e-01,  1.9611e-01,  2.1064e-01,  3.3362e-01,
         -2.3464e-01, -3.1525e-01,  3.0090e-01, -1.4328e-01,  3.2811e-01,
         -2.5250e-01,  1.2876e-03, -3.4962e-01, -3.2906e-01,  2.3347e-01,
         -2.5695e-01, -3.1041e-01,  9.6729e-02, -3.3356e-01, -1.7707e-01,
          3.0241e-01,  3.8306e-02, -1.7197e-01,  7.4984e-02,  1.7992e-01,
          3.1127e-01,  3.4695e-01, -3.4585e-01, -2.8037e-01,  3.4493e-01,
         -3.1164e-01, -3.0445e-01,  3.9136e-02,  3.4769e-01, -3.5068e-01,
         -2.8867e-01,  3.0238e-01,  3.4511e-01, -6.2755e-02, -2.9789e-01,
         -2.5160e-01, -3.4609e-01,  3.5162e-01, -1.4761e-01, -2.9846e-01,
         -2.2022e-01,  1.1042e-01,  2.8762e-01,  3.0170e-01, -2.2555e-01,
          2.2018e-01, -3.3848e-01,  3.5028e-01, -3.4993e-01, -2.3353e-01,
         -3.2610e-01, -3.3896e-01,  3.4588e-01,  3.3694e-01,  3.4409e-01,
          2.5323e-01,  3.2930e-01,  2.4329e-01,  2.3358e-01, -7.9981e-02,
          2.9703e-01,  3.4929e-01,  3.5171e-01, -3.1841e-01,  1.9108e-01,
         -2.4899e-01,  1.6254e-01,  3.5132e-01, -3.4985e-01,  1.5251e-01,
          2.1862e-02,  2.6289e-01,  1.4750e-01,  2.4689e-01,  3.2552e-01,
         -3.2441e-01, -3.4585e-01,  3.3942e-01, -1.2287e-01,  3.3929e-01,
         -3.4600e-01, -2.5490e-01, -1.4810e-01, -2.4914e-01, -3.1616e-01,
         -2.7996e-01,  5.3910e-02,  8.3362e-02, -2.2756e-01, -3.4595e-01,
          1.4570e-01,  2.2476e-01,  1.0500e-01, -3.5072e-01, -1.3940e-01,
         -1.2590e-01, -3.2186e-01, -2.0266e-01,  3.0027e-01, -5.6623e-02,
         -2.3789e-01,  3.3488e-01,  3.4398e-01, -3.1896e-01,  3.0663e-01,
          2.8664e-01,  3.3766e-01,  6.4165e-02, -3.1826e-01,  2.6319e-01,
         -2.2919e-01,  3.5134e-01,  3.0663e-01, -1.3062e-01, -2.9913e-01,
          1.3565e-01, -1.6022e-01,  2.2197e-01],
        [-2.1554e-02, -1.5210e-02,  3.7830e-02,  2.6736e-02, -2.1166e-02,
         -1.9468e-03,  8.0254e-02, -2.2102e-03,  2.5859e-03, -1.2345e-02,
          2.7264e-02,  1.7788e-02, -1.5296e-02,  1.9181e-03, -1.3894e-02,
          3.2927e-02,  8.5312e-03,  1.5223e-02,  2.1626e-02, -1.0207e-02,
          3.5204e-02,  2.5883e-02,  2.4556e-02,  1.8632e-02,  6.5476e-05,
         -1.6241e-02,  7.4333e-03,  5.2232e-02, -1.6711e-02,  3.2708e-02,
         -1.0692e-02, -1.4305e-02,  1.3133e-02,  2.7117e-02, -1.4753e-02,
          4.8611e-02,  1.4527e-02,  1.8879e-02, -1.4481e-02,  1.4897e-02,
          1.1238e-02, -1.2450e-02, -1.6202e-02,  4.9784e-03,  1.2289e-02,
          1.9180e-02,  1.4027e-02, -1.4833e-02,  1.6427e-02,  1.0015e-02,
          5.9121e-02,  5.6994e-02, -3.0438e-02, -1.1613e-02,  1.1159e-02,
         -4.3374e-02,  2.7599e-02, -2.2615e-02,  1.4398e-02,  3.0019e-02,
          1.0136e-02,  1.9818e-02, -1.4665e-02, -2.6874e-02, -1.8996e-02,
         -9.1702e-03, -1.3508e-02, -4.1096e-02, -2.9625e-02,  5.2459e-02,
         -4.2368e-02, -1.4232e-02, -1.5201e-02,  4.9654e-02, -1.1673e-02,
          5.3639e-02,  1.8897e-02, -1.4671e-02,  1.5038e-02,  2.0248e-02,
          2.6686e-02, -3.5407e-02, -1.6679e-02, -3.0283e-02, -2.9321e-02,
          1.8608e-02,  1.5144e-02, -1.2829e-02,  8.7272e-03, -1.8822e-02,
          1.6827e-02,  3.6772e-03,  5.2391e-03,  1.7730e-02,  9.1419e-03,
          6.0197e-02, -5.2879e-03,  1.0835e-02, -1.6823e-03,  1.7445e-02,
         -2.1828e-02, -2.4577e-02, -2.9082e-02,  1.5298e-02,  4.3386e-02,
         -4.2322e-04,  6.9954e-02, -9.6678e-03, -3.8367e-02,  1.6948e-02,
          3.5717e-02, -1.7567e-02, -3.4780e-02,  1.2462e-02, -1.0854e-02,
         -1.8944e-02, -2.4447e-02,  1.1157e-02,  1.2722e-02, -3.6463e-02,
          1.8334e-02, -1.6749e-02, -6.1643e-02,  6.6864e-03,  3.9645e-03,
         -4.8894e-03,  4.2964e-03,  2.5981e-03]], device='mps:0')
Grad: after - tensor([[ 0.1867, -0.2881],
        [-0.0076,  0.0535]], device='mps:0')
Grad: after - tensor([ 0.3526, -0.0147], device='mps:0')
Grad: after - tensor([ 0.3526, -0.0147], device='mps:0')
Are all parameters equal? False

Thanks again for your help!

OK, so it seems the gradients are calculated and the parameters updated.
It’s still unclear why your model isn’t learning anything. Could you compare the gradient magnitudes between PyTorch 1.13.1 and 2.0.1?
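Something along these lines should do it (just a sketch, assuming your model and a backward() call as in your snippet above):

# Print the L2 norm of each parameter's gradient after the first backward()
# call, so the values can be compared between the 1.13.1 and 2.0.1 runs.
for name, p in model.named_parameters():
    grad_norm = p.grad.norm().item() if p.grad is not None else float("nan")
    print(f"{name}: grad L2 norm = {grad_norm:.6f}")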

@ptrblck so the L2 distance between the gradients of the two versions, after the first backward() call, is around 0.45.
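For reference, a distance like this can be computed along the following lines (a rough sketch; the file names are hypothetical and the first-step gradients would need to be saved from each run beforehand):

# Rough sketch: flatten and concatenate the saved first-step gradients from
# each PyTorch version, then take the L2 distance between the two vectors.
# The file names below are hypothetical placeholders.
import torch

grads_1131 = torch.load("grads_1.13.1.pt")  # list of gradient tensors
grads_201 = torch.load("grads_2.0.1.pt")

flat_a = torch.cat([g.flatten().float().cpu() for g in grads_1131])
flat_b = torch.cat([g.flatten().float().cpu() for g in grads_201])
print(f"L2 distance between gradients: {torch.dist(flat_a, flat_b).item():.4f}")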

I’m unsure what norm values would be expected, but I also see that you are using the mps backend.
Do you see the same behavior in any other backend (e.g. CPU or CUDA)?

@ptrblck yes, we see the same behavior on NVIDIA GPUs and on the CPU.

Thanks for checking. In this case I would start removing parts of the model until e.g. only a single layer is used, to check if you can train it at all. I haven’t seen this kind of issue before, and since your parameters are updated, I guess something else might block the training.
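As a minimal sanity check, something along these lines could be used to verify that a single recurrent layer trains at all (only a sketch with toy data; your real model and data of course come from your code):

import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 2)
optim = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()), lr=1e-2)
criterion = nn.CrossEntropyLoss()

x = torch.randn(32, 5, 8)       # toy inputs: [batch, seq_len, features]
y = torch.randint(0, 2, (32,))  # toy binary targets

for step in range(200):
    optim.zero_grad()
    out, _ = rnn(x)             # out: [batch, seq_len, hidden]
    logits = head(out[:, -1])   # classify from the last time step
    loss = criterion(logits, y)
    loss.backward()
    optim.step()

print(f"final loss: {loss.item():.4f}")  # should drop clearly below the initial ~0.69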

@ptrblck so the model has just two RNN layers.
Switching to just one layer, we get the same bad results in PyTorch 2.0.1, while in 1.13.1 the performance goes from ~0.9 to ~0.75.

In that case I would assume it should be easy to post a minimal, executable code snippet reproducing the training stagnation, which we could then use to debug it.

@ptrblck absolutely. It is available here.

The sample contains conda environment files for each hardware setup (Mac/CPU and NVIDIA GPU) and for each of the PyTorch versions, 1.13.1 and 2.0.1.

To run it on CPU, the following command should be executed: python main.py --num_shot 1 --num_classes 2.

Thanks again for all your help so far @ptrblck !

Thanks! Unfortunately, your code isn’t executable without downloading an unknown dataset from an unknown source.
The comments in your model also don’t seem to be valid, as using any value combination for [B, K+1, N, 784] for the data and [B, K+1, N, N] results in shape mismatches, so could you let me know what the expected shapes are?

@ptrblck we are using the Omniglot dataset, i.e. the transpose of MNIST, as per the MANN paper we are implementing in the code sample. If this poses a security concern for you, we could switch to MNIST, though it would take some time.

The code should definitely run as it is, so I’m not sure what you mean regarding the comments about the shapes. Have you executed the code and received any errors?

I’m not sure if you already spotted the issue.
Thanks for providing the code to reproduce it.

I used Windows 11 to run all the combinations:

  • PyTorch 1.13 + CUDA (environment_113_cuda.yml)
  • PyTorch 2.0.1 + CUDA (environment_201_cuda.yml)
  • PyTorch 2.1.0 + CUDA (pytorch_nightly)

I faced an image normalization issue due to an imageio version mismatch (environment_113_cuda = imageio 2.19.3, environment_201_cuda = imageio 2.31.1).

In your code, the normalization is done as follows:

        image = imageio.imread(filename)
        image = image.reshape([dim_input])
        image = image.astype(np.float32) / 255.0

In imageio 2.19.3, the min and max values are [0, 255].
However, in imageio 2.31.1, the min and max values are [0., 1.]. Hence (image / 255.) makes the image (almost) all zeros in the PyTorch 2.0.1 conda environment.

After handling this normalization correctly (e.g. as below), all the PyTorch versions give similar performance (at least in my experiments).

        image = image / image.max()
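A normalization that does not depend on the imageio version could also look like this (just a sketch; filename and dim_input are taken from the snippet above):

import imageio
import numpy as np

# Rescale based on the loaded image's own value range instead of assuming
# uint8 values in [0, 255]; filename and dim_input come from the code above.
image = imageio.imread(filename)
image = np.asarray(image, dtype=np.float32).reshape([dim_input])
if image.max() > 1.0:       # e.g. imageio 2.19.3 returns values in [0, 255]
    image = image / 255.0   # otherwise the values are already in [0., 1.]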

I am not sure if this is the same issue that you are facing though.


@InnovArul thanks a lot! That was it!