How to convert a NumPy array to a torch.tensor without losing precision?

When I use torch.from_numpy I lose precision.
How can I solve the issue?
It’s affecting the model performance significantly.

Hi,

What is the data type of your numpy array?
Can you give an example please?

Before from_numpy:

array([[8.56253115e-01, 2.10600840e-02, 1.00000000e+00, 1.00000000e+00,
        0.00000000e+00, 1.00000000e+00, 7.10329978e-01, 1.00000000e+00,
        4.66952735e-01, 6.89281166e-01, 4.96416147e-01, 1.00000000e+00,
        5.56949372e-04, 1.85722837e-01, 1.00000000e+00, 1.00000000e+00,
        0.00000000e+00, 9.56788124e-01, 1.00000000e+00, 2.95043853e-01,
        1.00000000e+00, 1.50440394e-01, 1.64012738e-03, 1.00000000e+00,
        5.40330540e-01, 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
       [1.63444944e-02, 3.25443490e-04, 1.89391396e-03, 0.00000000e+00,
        2.64070302e-01, 2.12084330e-02, 1.88466696e-03, 5.66846773e-05,
        3.61758789e-04, 1.16182895e-01, 7.51278105e-02, 2.63272193e-01,
        0.00000000e+00, 3.20013679e-05, 1.00000000e+00, 2.04047972e-09,
        5.92447090e-02, 9.22426677e-01, 3.19583818e-01, 2.35541376e-02,
        1.02123520e-03, 0.00000000e+00, 5.90462369e-05, 3.88561546e-01,
        1.35224173e-01, 0.00000000e+00, 1.41158704e-04, 0.00000000e+00],
       [4.22830837e-01, 4.06760181e-03, 4.62998442e-02, 1.09968540e-03,
        3.96297000e-02, 0.00000000e+00, 1.78848070e-01, 2.06536186e-01,
        1.00919146e-02, 1.68430658e-02, 5.48692875e-02, 1.81380361e-02,
        1.21907962e-03, 2.46930120e-04, 1.00000000e+00, 0.00000000e+00,
        2.26129147e-04, 9.22508457e-01, 1.25277173e-02, 8.64714525e-02,
        3.67948620e-02, 3.40576792e-01, 2.05638313e-02, 7.67841625e-02,
        2.80965836e-02, 1.11051388e-02, 9.35414430e-02, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 9.11751211e-05, 0.00000000e+00,
        2.15352661e-02, 0.00000000e+00, 9.50547046e-06, 8.84862763e-08,
        0.00000000e+00, 4.04984295e-03, 8.82974618e-03, 0.00000000e+00,
        4.44519766e-09, 1.21600648e-07, 1.00000000e+00, 0.00000000e+00,
        1.97038222e-01, 9.22414541e-01, 2.19747690e-02, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.21929856e-02,
        0.00000000e+00, 0.00000000e+00, 1.70121192e-05, 0.00000000e+00],
       [1.00000000e+00, 5.40273858e-02, 1.51650088e-01, 0.00000000e+00,
        1.00000000e+00, 6.21058694e-01, 5.94487888e-04, 2.62489008e-01,
        7.05602436e-01, 1.00000000e+00, 7.21583397e-01, 2.52645551e-01,
        1.00000000e+00, 1.64927351e-04, 1.00000000e+00, 8.50199882e-10,
        5.31581714e-01, 0.00000000e+00, 4.24930773e-01, 1.00000000e+00,
        1.74178816e-01, 1.00000000e+00, 1.00000000e+00, 9.25819980e-01,
        1.00000000e+00, 1.00000000e+00, 7.10783003e-01, 2.38598855e-02],
       [1.29854075e-01, 3.68072346e-03, 3.12879515e-02, 1.46362012e-04,
        3.00978717e-02, 5.00730012e-02, 1.33110302e-02, 1.14980251e-01,
        4.00726025e-03, 1.49724841e-02, 2.99440003e-02, 1.10569669e-02,
        1.07195942e-05, 5.77660268e-03, 1.00000000e+00, 1.01228183e-02,
        2.34830829e-01, 9.25345490e-01, 7.53145551e-02, 1.13546897e-02,
        8.83749801e-02, 3.17535910e-01, 0.00000000e+00, 3.99295232e-02,
        1.10070510e-02, 0.00000000e+00, 6.56735153e-02, 6.05100330e-04],
       [2.97523188e-03, 6.29734920e-03, 1.41516693e-03, 1.96638629e-04,
        3.08878840e-03, 9.00023977e-02, 1.85204691e-02, 1.33518970e-02,
        3.10382062e-03, 3.61596649e-02, 0.00000000e+00, 7.38495239e-02,
        2.19608600e-04, 1.74123890e-03, 1.00000000e+00, 1.05341466e-05,
        0.00000000e+00, 9.23519131e-01, 3.01608445e-02, 0.00000000e+00,
        1.03903584e-03, 4.26741719e-02, 5.56524732e-03, 0.00000000e+00,
        4.62104603e-02, 1.54201911e-01, 0.00000000e+00, 0.00000000e+00],
       [1.14535242e-04, 0.00000000e+00, 2.38379219e-05, 2.34577964e-06,
        7.92642302e-02, 1.14958253e-03, 4.33216630e-03, 1.44780434e-02,
        5.64552090e-03, 3.02162615e-02, 1.43900126e-02, 1.23116592e-01,
        2.09199337e-05, 4.75046659e-04, 1.00000000e+00, 0.00000000e+00,
        1.09568696e-04, 9.22474406e-01, 1.14317752e-01, 2.00486232e-02,
        1.52129573e-04, 0.00000000e+00, 0.00000000e+00, 1.64409378e-02,
        2.66473317e-02, 3.75389223e-03, 6.00361687e-03, 1.17388539e-05],
       [1.22396757e-03, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        3.25909677e-02, 1.83675408e-02, 3.48234958e-02, 7.26948680e-02,
        4.98644073e-02, 1.27746565e-02, 1.80162938e-04, 1.19618785e-01,
        9.61646279e-05, 1.17811415e-05, 1.00000000e+00, 9.22074080e-04,
        4.09924927e-01, 9.22415982e-01, 3.32321444e-02, 2.13680593e-02,
        1.82387691e-02, 0.00000000e+00, 5.02314309e-07, 9.73318299e-02,
        2.09672840e-02, 0.00000000e+00, 2.10699437e-02, 1.49939282e-03],
       [4.97427239e-01, 4.63787900e-01, 2.14420006e-02, 0.00000000e+00,
        3.79694126e-02, 0.00000000e+00, 8.51221007e-02, 6.93142498e-04,
        4.30852666e-02, 1.42149405e-02, 5.37458993e-02, 2.95120607e-02,
        1.13908190e-05, 3.12445411e-04, 1.00000000e+00, 0.00000000e+00,
        1.64243681e-04, 9.22537366e-01, 3.80730508e-02, 5.71566821e-02,
        2.49279446e-02, 1.50269362e-02, 1.24486590e-03, 9.54693236e-02,
        3.91292578e-02, 9.82228232e-01, 1.79861819e-02, 0.00000000e+00],
       [6.83185251e-02, 2.07765139e-04, 1.66174386e-02, 0.00000000e+00,
        2.46454734e-02, 1.05619145e-03, 2.15612709e-02, 9.95882936e-02,
        2.95240113e-02, 5.74595955e-02, 6.20013459e-02, 1.92998231e-02,
        1.86142652e-08, 9.40225387e-03, 1.00000000e+00, 5.61166270e-04,
        3.62169988e-03, 9.23606726e-01, 8.79685445e-02, 1.56105102e-02,
        2.09301008e-02, 0.00000000e+00, 3.02656700e-06, 2.72933796e-02,
        5.60200079e-02, 0.00000000e+00, 3.37717824e-02, 0.00000000e+00],
       [1.86204989e-02, 2.54045624e-03, 1.68593824e-02, 0.00000000e+00,
        1.24811369e-02, 8.57603909e-02, 3.18599562e-03, 9.39539908e-02,
        5.50593358e-02, 1.68727578e-02, 5.55021458e-05, 6.39370958e-02,
        1.26132484e-04, 1.75310971e-02, 1.00000000e+00, 0.00000000e+00,
        1.15974810e-01, 9.24638982e-01, 5.72684123e-02, 2.72345146e-03,
        3.98931077e-02, 1.29441406e-01, 0.00000000e+00, 1.60654499e-02,
        3.34304918e-02, 1.29379525e-01, 1.31762752e-02, 0.00000000e+00],
       [5.39866637e-02, 1.01618250e-02, 5.36334391e-02, 0.00000000e+00,
        5.19920783e-02, 0.00000000e+00, 1.47816193e-02, 7.84947010e-03,
        2.91576359e-04, 6.98026791e-02, 1.23393163e-02, 1.10314396e-01,
        3.05607339e-06, 6.68863392e-01, 1.00000000e+00, 8.94512300e-03,
        2.74164938e-01, 9.75634989e-01, 7.53980639e-02, 1.41027028e-02,
        3.27059290e-02, 0.00000000e+00, 3.00402274e-03, 5.40874452e-02,
        5.90167192e-03, 2.52723854e-02, 2.24714011e-02, 3.21517377e-02],
       [8.81597531e-02, 0.00000000e+00, 2.04109594e-02, 0.00000000e+00,
        3.14068803e-02, 1.35497106e-02, 1.91138731e-01, 3.36800889e-03,
        5.28725131e-03, 9.86190529e-03, 1.13987532e-02, 4.10397796e-02,
        4.87110316e-03, 9.59178064e-05, 1.00000000e+00, 3.29877554e-05,
        8.81284667e-03, 9.22425917e-01, 4.13502574e-02, 4.22125142e-02,
        3.58454202e-02, 4.52446793e-02, 1.43296957e-03, 5.51388884e-02,
        2.68748107e-02, 0.00000000e+00, 2.64280590e-04, 1.00000000e+00],
       [3.66515716e-01, 1.54537058e-03, 1.11198586e-02, 8.21531017e-04,
        1.18326454e-01, 3.83256765e-02, 1.36536718e-02, 2.51232817e-04,
        1.00000000e+00, 2.91909059e-02, 1.00000000e+00, 1.31532440e-01,
        0.00000000e+00, 9.81513357e-06, 1.00000000e+00, 2.18424342e-04,
        2.74737998e-01, 9.22465896e-01, 4.22516952e-02, 1.13922333e-01,
        2.22064072e-02, 0.00000000e+00, 4.06501201e-07, 1.30307932e-01,
        8.96385263e-02, 2.03989642e-03, 2.20132367e-02, 0.00000000e+00]])

After:

tensor([[0.8563, 0.0211, 1.0000, 1.0000, 0.0000, 1.0000, 0.7103, 1.0000, 0.4670,
         0.6893, 0.4964, 1.0000, 0.0006, 0.1857, 1.0000, 1.0000, 0.0000, 0.9568,
         1.0000, 0.2950, 1.0000, 0.1504, 0.0016, 1.0000, 0.5403, 0.0000, 1.0000,
         0.0000],
        [0.0163, 0.0003, 0.0019, 0.0000, 0.2641, 0.0212, 0.0019, 0.0001, 0.0004,
         0.1162, 0.0751, 0.2633, 0.0000, 0.0000, 1.0000, 0.0000, 0.0592, 0.9224,
         0.3196, 0.0236, 0.0010, 0.0000, 0.0001, 0.3886, 0.1352, 0.0000, 0.0001,
         0.0000],
        [0.4228, 0.0041, 0.0463, 0.0011, 0.0396, 0.0000, 0.1788, 0.2065, 0.0101,
         0.0168, 0.0549, 0.0181, 0.0012, 0.0002, 1.0000, 0.0000, 0.0002, 0.9225,
         0.0125, 0.0865, 0.0368, 0.3406, 0.0206, 0.0768, 0.0281, 0.0111, 0.0935,
         0.0000],
        [0.0000, 0.0000, 0.0001, 0.0000, 0.0215, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0040, 0.0088, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.1970, 0.9224,
         0.0220, 0.0000, 0.0000, 0.0000, 0.0000, 0.0222, 0.0000, 0.0000, 0.0000,
         0.0000],
        [1.0000, 0.0540, 0.1517, 0.0000, 1.0000, 0.6211, 0.0006, 0.2625, 0.7056,
         1.0000, 0.7216, 0.2526, 1.0000, 0.0002, 1.0000, 0.0000, 0.5316, 0.0000,
         0.4249, 1.0000, 0.1742, 1.0000, 1.0000, 0.9258, 1.0000, 1.0000, 0.7108,
         0.0239],
        [0.1299, 0.0037, 0.0313, 0.0001, 0.0301, 0.0501, 0.0133, 0.1150, 0.0040,
         0.0150, 0.0299, 0.0111, 0.0000, 0.0058, 1.0000, 0.0101, 0.2348, 0.9253,
         0.0753, 0.0114, 0.0884, 0.3175, 0.0000, 0.0399, 0.0110, 0.0000, 0.0657,
         0.0006],
        [0.0030, 0.0063, 0.0014, 0.0002, 0.0031, 0.0900, 0.0185, 0.0134, 0.0031,
         0.0362, 0.0000, 0.0738, 0.0002, 0.0017, 1.0000, 0.0000, 0.0000, 0.9235,
         0.0302, 0.0000, 0.0010, 0.0427, 0.0056, 0.0000, 0.0462, 0.1542, 0.0000,
         0.0000],
        [0.0001, 0.0000, 0.0000, 0.0000, 0.0793, 0.0011, 0.0043, 0.0145, 0.0056,
         0.0302, 0.0144, 0.1231, 0.0000, 0.0005, 1.0000, 0.0000, 0.0001, 0.9225,
         0.1143, 0.0200, 0.0002, 0.0000, 0.0000, 0.0164, 0.0266, 0.0038, 0.0060,
         0.0000],
        [0.0012, 1.0000, 0.0000, 0.0000, 0.0326, 0.0184, 0.0348, 0.0727, 0.0499,
         0.0128, 0.0002, 0.1196, 0.0001, 0.0000, 1.0000, 0.0009, 0.4099, 0.9224,
         0.0332, 0.0214, 0.0182, 0.0000, 0.0000, 0.0973, 0.0210, 0.0000, 0.0211,
         0.0015],
        [0.4974, 0.4638, 0.0214, 0.0000, 0.0380, 0.0000, 0.0851, 0.0007, 0.0431,
         0.0142, 0.0537, 0.0295, 0.0000, 0.0003, 1.0000, 0.0000, 0.0002, 0.9225,
         0.0381, 0.0572, 0.0249, 0.0150, 0.0012, 0.0955, 0.0391, 0.9822, 0.0180,
         0.0000],
        [0.0683, 0.0002, 0.0166, 0.0000, 0.0246, 0.0011, 0.0216, 0.0996, 0.0295,
         0.0575, 0.0620, 0.0193, 0.0000, 0.0094, 1.0000, 0.0006, 0.0036, 0.9236,
         0.0880, 0.0156, 0.0209, 0.0000, 0.0000, 0.0273, 0.0560, 0.0000, 0.0338,
         0.0000],
        [0.0186, 0.0025, 0.0169, 0.0000, 0.0125, 0.0858, 0.0032, 0.0940, 0.0551,
         0.0169, 0.0001, 0.0639, 0.0001, 0.0175, 1.0000, 0.0000, 0.1160, 0.9246,
         0.0573, 0.0027, 0.0399, 0.1294, 0.0000, 0.0161, 0.0334, 0.1294, 0.0132,
         0.0000],
        [0.0540, 0.0102, 0.0536, 0.0000, 0.0520, 0.0000, 0.0148, 0.0078, 0.0003,
         0.0698, 0.0123, 0.1103, 0.0000, 0.6689, 1.0000, 0.0089, 0.2742, 0.9756,
         0.0754, 0.0141, 0.0327, 0.0000, 0.0030, 0.0541, 0.0059, 0.0253, 0.0225,
         0.0322],
        [0.0882, 0.0000, 0.0204, 0.0000, 0.0314, 0.0135, 0.1911, 0.0034, 0.0053,
         0.0099, 0.0114, 0.0410, 0.0049, 0.0001, 1.0000, 0.0000, 0.0088, 0.9224,
         0.0414, 0.0422, 0.0358, 0.0452, 0.0014, 0.0551, 0.0269, 0.0000, 0.0003,
         1.0000],
        [0.3665, 0.0015, 0.0111, 0.0008, 0.1183, 0.0383, 0.0137, 0.0003, 1.0000,
         0.0292, 1.0000, 0.1315, 0.0000, 0.0000, 1.0000, 0.0002, 0.2747, 0.9225,
         0.0423, 0.1139, 0.0222, 0.0000, 0.0000, 0.1303, 0.0896, 0.0020, 0.0220,
         0.0000]])

Hi,

I think they both contain the same thing. The difference is just how many digits are printed when the tensor is displayed.

I didn’t mention it at first, but I compared my results to the same model in Keras, and there seems to be a significant difference. I started investigating and found this problem.

You can extract the first element of these matrices as python numbers and print them. You will see that they contain exactly the same thing.
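A minimal sketch of that check. It assumes a small array built from the first two values of the printout above, stored as float64 (NumPy's default; the thread does not say which dtype was actually used). `.item()` turns a single element into a plain Python float, so the two values can be compared directly:

```python
import numpy as np
import torch

# First two values of the "Before" printout; NumPy infers float64 here
arr = np.array([[8.56253115e-01, 2.10600840e-02]])
t = torch.from_numpy(arr)  # keeps the dtype (float64) and shares memory with arr

np_first = arr[0, 0].item()     # plain Python float taken from NumPy
torch_first = t[0, 0].item()    # plain Python float taken from PyTorch
print(np_first, torch_first)
print(np_first == torch_first)  # True: same value, only the display differs
print(t.dtype)                  # torch.float64
```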

Also torch.set_printoptions(precision=8) might be helpful to have a quick look at the data.
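A short sketch of what that changes. The values are taken from the printout above and stored as float32 (PyTorch's default dtype). Only the display changes; the stored numbers stay the same:

```python
import torch

t = torch.tensor([0.856253115, 0.021060084])  # float32, PyTorch's default dtype

print(t)  # default display: 4 digits after the decimal point
torch.set_printoptions(precision=8)
print(t)  # same tensor, now displayed with 8 digits
torch.set_printoptions(profile="default")  # restore the default display settings
```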

Hello,

@andreiliphd Is there any follow-up on your issue?

I am using from_numpy to convert a Keras checkpoint to PyTorch, and I am also seeing slight changes in the outputs of each layer (despite the weights being the same as in Keras). These changes propagate and affect the performance of the model substantially. I suspect this problem has to do with precision, but I can’t figure it out.

Let me know if you have solved your problem. Thanks!
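When porting a Keras checkpoint, the weight layout is also worth double-checking. A minimal sketch follows. It assumes a channels_last Keras Conv2D whose kernel, as returned by `layer.get_weights()`, has shape (kh, kw, in_channels, out_channels). PyTorch's `nn.Conv2d` stores `weight` as (out_channels, in_channels, kh, kw). The random arrays are stand-ins for a real checkpoint:

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-ins for one Keras Conv2D layer's weights (e.g. from layer.get_weights()):
# kernel shape (kh, kw, in_channels, out_channels), bias shape (out_channels,)
rng = np.random.default_rng(0)
keras_kernel = rng.standard_normal((3, 3, 1, 16)).astype(np.float32)
keras_bias = rng.standard_normal(16).astype(np.float32)

conv = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

with torch.no_grad():
    # Reorder (kh, kw, in, out) -> (out, in, kh, kw), the layout nn.Conv2d expects.
    # copy_ writes into the existing Parameter and keeps its float32 dtype.
    conv.weight.copy_(torch.from_numpy(keras_kernel.transpose(3, 2, 0, 1)))
    conv.bias.copy_(torch.from_numpy(keras_bias))

print(conv.weight.shape)  # torch.Size([16, 1, 3, 3])
```

Both frameworks compute cross-correlation, so no kernel flip is needed. Inputs also differ: Keras (channels_last) expects NHWC, while PyTorch expects NCHW, so an input array would need something like `x.permute(0, 3, 1, 2)` before being fed to the PyTorch model.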

Hi, @Rodrigo_Mira!

I am a happy PyTorch user now and didn’t investigate the problem that much.

All the best,
Andrei

Thank you for the response @andreiliphd.

In any case, I still want to try to solve this issue.

@ptrblck I used torch.set_printoptions(precision=8). At first glance, the weights appear to be the same after being converted from NumPy to torch. But if I print the norm of each weight tensor in both formats (NumPy and torch), I get small differences such as:

NP : 13.187959
Torch: tensor(13.18795586)

These are definitely not equivalent. This seemed fishy to me, so I printed both NumPy and torch values with precision 32 and got this:

NP: 0.071544915
Torch: 0.07154491543769836425781250000000

So the Torch tensor appears to contain a lot of garbage that was not in the original NumPy vector. Both types are float32, so the Torch tensor should not even be able to hold that many digits (I think).
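The long PyTorch string need not be extra data. Every float32 is a binary fraction whose exact decimal expansion is finite, and printing with 32 fixed decimals shows all of it. NumPy instead prints the shortest string that still rounds back to the same float32. A small check using the two strings from this post, relying only on NumPy and the standard-library `decimal` module:

```python
from decimal import Decimal

import numpy as np

numpy_str = "0.071544915"                         # NumPy's shortest round-trip repr
torch_str = "0.07154491543769836425781250000000"  # PyTorch with precision=32

a = np.float32(numpy_str)
b = np.float32(torch_str)
print(a == b)             # True: both strings denote the same float32
print(Decimal(float(a)))  # 0.0715449154376983642578125, the exact stored value
print(a)                  # 0.071544915
```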

It’s all pretty weird, and using this checkpoint, the outputs of each layer in the model are slightly different. These differences propagate into a very different final performance.

Am I missing something here? Would really appreciate any help you can spare.

Thanks a lot,
Rodrigo

Edit: My code is basically just:
torch_weight = torch.from_numpy(numpy_weight)
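One way to separate the conversion from everything downstream is to compare the arrays bit for bit first, and only then compute the norms. This sketch uses a random float32 array as a stand-in for `numpy_weight`. The two norms may still differ in the last few bits, because NumPy and PyTorch are free to sum the squares in different orders even when the inputs are identical:

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
numpy_weight = rng.standard_normal((64, 64, 3, 3)).astype(np.float32)  # stand-in weights
torch_weight = torch.from_numpy(numpy_weight)

# The conversion itself is exact: same dtype and same bits (from_numpy shares memory)
print(torch_weight.dtype)                                  # torch.float32
print(np.array_equal(numpy_weight, torch_weight.numpy()))  # True

# Reductions may round differently in each library
np_norm = np.linalg.norm(numpy_weight)
torch_norm = torch.linalg.norm(torch_weight)
print(np_norm, torch_norm.item())
```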

FP32 precision is usually limited to ~1e-6, so your parameter differences should be in this range between both implementations.
If you are using a lot of modules, these differences might accumulate and change the final output by a larger magnitude.
Could you print the differences for each layer and activation?
This could tell us whether the issue is really related to floating-point precision or to some other bug.
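A sketch of how those per-layer differences could be collected on the PyTorch side. It uses forward hooks on every `nn.ReLU` of an `nn.Sequential`, since the Keras layers here apply `activation='relu'` inside Conv2D. It assumes the Keras activations (for example `conv_out_1` … `conv_out_12`) have already been evaluated to NumPy arrays in channels_last (NHWC) layout. The helper names are hypothetical:

```python
import numpy as np
import torch
import torch.nn as nn


def collect_relu_outputs(model: nn.Sequential, x: torch.Tensor) -> list:
    """Run `model` once and return every ReLU output as an NCHW NumPy array."""
    outputs = []
    handles = [
        m.register_forward_hook(
            lambda module, inputs, output: outputs.append(output.detach().cpu().numpy())
        )
        for m in model
        if isinstance(m, nn.ReLU)
    ]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:  # always remove the hooks again
            h.remove()
    return outputs


def compare_layers(torch_outs: list, keras_outs: list) -> list:
    """Print and return the max absolute difference per layer (Keras outputs assumed NHWC)."""
    diffs = []
    for i, (t_out, k_out) in enumerate(zip(torch_outs, keras_outs), start=1):
        k_out = np.transpose(k_out, (0, 3, 1, 2))  # NHWC -> NCHW
        max_abs = float(np.max(np.abs(t_out.astype(np.float64) - k_out.astype(np.float64))))
        print(f"layer {i:2d}: max abs diff {max_abs:.3e}")
        diffs.append(max_abs)
    return diffs


# Self-check on a tiny stand-in model: comparing it with its own outputs gives zeros
model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
outs = collect_relu_outputs(model, torch.randn(2, 1, 8, 8))
fake_keras = [np.transpose(o, (0, 2, 3, 1)) for o in outs]  # NCHW -> NHWC
compare_layers(outs, fake_keras)
```

Differences around 1e-6 relative to the activation scale would point to rounding; much larger jumps at a specific layer would point to a mismatch in that layer.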

Thanks a lot for the response @ptrblck.

So the models are as follows.
PyTorch:
(cnn): Sequential(
  (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU()
  (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 3), padding=(1, 1))
  (5): ReLU()
  (6): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU()
  (8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU()
  (10): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 3), padding=(1, 1))
  (11): ReLU()
  (12): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU()
  (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU()
  (16): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 3), padding=(1, 1))
  (17): ReLU()
  (18): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (19): ReLU()
  (20): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (21): ReLU()
  (22): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 3), padding=(1, 1))
  (23): ReLU()
)
Keras:
from tensorflow.keras.layers import Conv2D  # or: from keras.layers import Conv2D
# re_input is the model's input tensor, defined elsewhere in the original code

conv1 = Conv2D(16, (3, 3), strides=(1, 1), activation='relu', padding='same')(re_input)
conv_out_1 = conv1
conv1 = Conv2D(16, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv1)
conv_out_2 = conv1
conv1 = Conv2D(16, (3, 3), strides=(1, 3), activation='relu', padding='same')(conv1)
conv_out_3 = conv1

conv2 = Conv2D(32, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv1)
conv_out_4 = conv2
conv2 = Conv2D(32, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv2)
conv_out_5 = conv2
conv2 = Conv2D(32, (3, 3), strides=(1, 3), activation='relu', padding='same')(conv2)
conv_out_6 = conv2

conv3 = Conv2D(64, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv2)
conv_out_7 = conv3
conv3 = Conv2D(64, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv3)
conv_out_8 = conv3
conv3 = Conv2D(64, (3, 3), strides=(1, 3), activation='relu', padding='same')(conv3)
conv_out_9 = conv3

conv4 = Conv2D(128, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv3)
conv_out_10 = conv4
conv4 = Conv2D(128, (3, 3), strides=(1, 1), activation='relu', padding='same')(conv4)
conv_out_11 = conv4
conv4 = Conv2D(128, (3, 3), strides=(1, 3), activation='relu', padding='same')(conv4)
conv_out_12 = conv4

Messy formatting, sorry, but they should be 100% equivalent. The outputs are as follows (I print the mean and std of the output vectors at each of the 12 Conv layers):
PyTorch:
mean 0.1638 std 0.2410
mean 0.1336 std 0.2011
mean 0.1350 std 0.1917
mean 0.0607 std 0.1112
mean 0.0305 std 0.0819
mean 0.0276 std 0.0563
mean 0.0128 std 0.0345
mean 0.0222 std 0.0392
mean 0.0333 std 0.0613
mean 0.0319 std 0.0750
mean 0.1114 std 0.1933
mean 0.3793 std 0.6193
Keras:
mean 0.1639 std 0.2410
mean 0.1339 std 0.2010
mean 0.1360 std 0.1920
mean 0.0608 std 0.1117
mean 0.0305 std 0.0825
mean 0.0277 std 0.0565
mean 0.0130 std 0.0349
mean 0.0222 std 0.0395
mean 0.0337 std 0.0616
mean 0.0299 std 0.0726
mean 0.1088 std 0.1867
mean 0.3760 std 0.6399

As you can see, the difference is very small at the beginning but grows as it propagates. I also feed the outputs of this into a BLSTM and then into FC layers, so the error propagates even further, but I want to deal with this issue first. Do you think there is any way to fix this? Thanks!
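Before attributing the whole gap to float32 rounding, it may be worth confirming that the strided layers pad the same way. TensorFlow's 'same' padding is computed from the input size and can be zero or asymmetric when the stride is larger than 1. By contrast, `padding=(1, 1)` in PyTorch always adds one pixel on each side. Below is a sketch of TensorFlow's 'SAME' rule for one spatial dimension; the input width of 9 is a hypothetical example:

```python
import math


def tf_same_padding(in_size: int, kernel: int, stride: int) -> tuple:
    """(pad_before, pad_after) used by TensorFlow 'SAME' padding along one dimension."""
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
    pad_before = pad_total // 2
    return pad_before, pad_total - pad_before


def torch_out_size(in_size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of nn.Conv2d along one dimension (dilation 1)."""
    return (in_size + 2 * padding - kernel) // stride + 1


width = 9  # hypothetical input width
for stride in (1, 3):
    print(
        f"stride {stride}: Keras 'same' pads {tf_same_padding(width, 3, stride)}, "
        f"PyTorch padding=1 pads (1, 1); output widths "
        f"{math.ceil(width / stride)} vs {torch_out_size(width, 3, stride, 1)}"
    )
```

With stride 3 and a width divisible by 3, both frameworks produce the same output width. However, TensorFlow adds no padding, while PyTorch pads one pixel on each side, so the sampling windows are shifted. This alone would not explain the first-layer difference reported below, since the stride-1 layers pad (1, 1) in both frameworks. If it does apply, padding manually with `torch.nn.functional.pad` to TensorFlow's amounts could be tried. This is a suggestion, not something confirmed in this thread.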

Edit: BTW, I compared the outputs after the first layer, without activation, and they are indeed also not the same:
Pytorch:
mean -0.0057 std 0.4149
Keras:
mean -0.0054 std 0.4144

Any ideas? I would really appreciate some follow-up on this matter. @ptrblck
Thanks.

Hi,

Unfortunately, small errors are magnified by the following layers, so the slightest difference can lead to completely different network outputs.
Bit-perfect equivalence of the operations is not easy to achieve. Even within PyTorch, we don’t guarantee it by default between two runs of the same code (see the determinism notes in the docs). So matching Keras bit for bit is almost impossible.
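For reference, here is a minimal sketch of the usual PyTorch settings for reproducible runs. These settings only make repeated PyTorch runs consistent with each other; they do not make PyTorch match Keras. On CUDA, some operations may additionally require environment settings described in the reproducibility notes:

```python
import torch

torch.manual_seed(0)                      # same random weights and inputs every run
torch.use_deterministic_algorithms(True)  # use deterministic kernels, raise if an op has none
torch.backends.cudnn.benchmark = False    # don't let cuDNN pick algorithms per run

conv = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1)
x = torch.randn(4, 1, 32, 32)
with torch.no_grad():
    y1 = conv(x)
    y2 = conv(x)
print(torch.equal(y1, y2))  # True
```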

Hello @albanD,

Ok, I totally understand. I ended up training my own model on PyTorch and it’s working well. In any case, I really appreciate the feedback and will look into the notes you mentioned to further my understanding of this issue. Thanks a lot and keep up the good work!

So, in the end, how can we solve this problem?

I use PyTorch to accelerate large matrix operations rather than to train deep nets. However, the precision "garbage" introduces errors that propagate and make the final result fail.
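If the end result is sensitive to rounding, one option is to do the heavy computation in float64 and accept the speed and memory cost. Here is a sketch in NumPy, with random 512 x 512 matrices chosen only for illustration. It measures how far a float32 product drifts from a float64 reference:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((512, 512))  # float64
b = rng.standard_normal((512, 512))

reference = a @ b  # float64 result
approx = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float64)

rel_err = np.max(np.abs(approx - reference)) / np.max(np.abs(reference))
print(f"float32 vs float64, max relative error: {rel_err:.2e}")
```

In PyTorch, the same idea means keeping the tensors in torch.float64, for example via torch.from_numpy on a float64 array or torch.tensor(..., dtype=torch.float64). On many GPUs, float64 throughput is much lower than float32.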

Hello, I think I came across the same issue. I had longitude and latitude values and got different results with NumPy and with PyTorch: the losses of a NumPy-based scikit-learn approach and of a PyTorch approach were quite different because of the nature of my values. The issue was solved by using

torch.tensor(values, dtype=torch.float64)

Instead of my initial

torch.Tensor(values)

Hope it helps
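This fix likely works because torch.Tensor(...) always builds a tensor of the default dtype (float32 unless changed with torch.set_default_dtype). A float64 NumPy array is therefore silently rounded, whereas torch.tensor and torch.from_numpy keep float64. A sketch with two made-up coordinates:

```python
import numpy as np
import torch

values = np.array([45.123456789, -122.987654321])  # hypothetical lat/lon, float64

print(torch.Tensor(values).dtype)                       # torch.float32 (rounded copy)
print(torch.tensor(values).dtype)                       # torch.float64 (dtype inferred)
print(torch.from_numpy(values).dtype)                   # torch.float64 (shares memory)
print(torch.tensor(values, dtype=torch.float64).dtype)  # torch.float64

# Rounding error introduced by the float32 copy
print((torch.Tensor(values).double() - torch.from_numpy(values)).abs().max())
```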