Independence of matrix rows in matmul

Hi,
According to linear algebra, the result of multiplying a matrix by a matrix can be computed separately by rows (or by columns). Why, then, are the two following results not exactly equivalent? Do I need to enable any other flag or choose a different backend?

import random
import os
import numpy as np
import torch

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # benchmarking can select kernels non-deterministically

model = torch.randn(1000,1000)
inputs = torch.randn(2,1000)
with torch.inference_mode():
    a = (inputs @ model)[0]
    b = inputs[0] @ model
    print(torch.all(torch.isclose(a, b, rtol=1e-3)))  # True
    print(torch.all(torch.isclose(a, b, rtol=1e-4)))  # False

Different algorithms can be picked depending on the input shapes of your workload.
My system returns the same results, but your CPU might take an optimized code path for one of the operations. Generally you cannot expect bitwise-identical results due to the limited floating point precision, and you would need to call torch.use_deterministic_algorithms(True) if needed.
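
For completeness, here is a minimal sketch of what opting into deterministic algorithms looks like. Note that it only makes a given workload reproducible across runs; it does not force differently shaped workloads to agree with each other:

# Request deterministic implementations where available; ops without a
# deterministic implementation will raise an error instead of running.
torch.use_deterministic_algorithms(True)

model = torch.randn(1000, 1000)
inputs = torch.randn(2, 1000)
a = (inputs @ model)[0]
b = inputs[0] @ model
# a and b are now reproducible across runs, but they can still differ from
# each other, since a different algorithm may be picked for each shape.
print((a - b).abs().max())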

Adding torch.use_deterministic_algorithms(True) does not change the results I’m getting, but increasing the floating point precision does make both equal (I’m running in a Colab notebook for reproducibility).
Any explanation for why that would be the case? I would think the two accumulations (one for the first output row and one for the second) happen independently, but apparently that is not what is happening, even without any non-deterministic algorithms.
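
Concretely, this is the precision check I mean (the same tensors as above, just cast to float64):

# Repeating the comparison in double precision: with float64 accumulation
# the two results agree even at a much tighter tolerance.
a64 = (inputs.double() @ model.double())[0]
b64 = inputs[0].double() @ model.double()
print(torch.all(torch.isclose(a64, b64, rtol=1e-4)))  # True in my runs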

Sorry for the unclear description. Enabling deterministic algorithms should return bitwise-identical and deterministic results for the same workload, which is defined by e.g. the shapes of the used tensors. Different algorithms can still be picked for differently shaped tensors, and a small example also shows this for simple reductions:

x = torch.randn(100, 100, 100)
s1 = x.sum()
s2 = x.sum(0).sum(0).sum(0)
print(s1 - s2)
# tensor(0.0002)
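
As a quick sanity check, the same reduction in float64 shrinks the gap to around machine epsilon, which indicates it is purely an accumulation-order/rounding effect (a small sketch reusing x from above):

# Same reduction in double precision: the accumulation order still differs,
# but the rounding error becomes negligible.
x64 = x.double()
print(x64.sum() - x64.sum(0).sum(0).sum(0))
# expected to be on the order of 1e-10 or smaller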

Thanks @ptrblck, I appreciate your answers.
What surprises me, and to my understanding differentiates the two examples, is that in the original example the two computations do not, at least in theory, even share any intermediate result: theoretically none of the scalars generated while computing (inputs @ model)[1] takes part in the computation of (inputs @ model)[0], so I do not see where the inconsistency comes from.
In the example you give, I can guess that the difference results from a different order of accumulating the coordinates of x (but I could be wrong here as well). Is that what happens in my example as well? That regardless of the additional (inputs @ model)[1], the coordinates of (inputs @ model)[0] are accumulated in a different order than they are in inputs[0] @ model, resulting in a different rounding error?
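
For instance, this is the kind of effect I have in mind, a toy sketch that sums the same float32 values in two different orders:

# Summing identical float32 values in a different order gives a slightly
# different result, because the intermediate partial sums are rounded
# differently.
v = torch.randn(10000)
print(v.sum() - v.sort().values.sum())
# typically a small non-zero difference, e.g. on the order of 1e-4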

But this is also not what you are comparing.
In your example you are executing a large workload and slicing its result afterwards, versus a small workload created by slicing the inputs first:

a = (inputs @ model)[0]
b = inputs[0] @ model

Both approaches can pick different algorithms, which may use a different order of operations, resulting in the expected error due to the limited floating point precision.

I would recommend profiling the code to verify my claims:

model = torch.randn(1000,1000).cuda()
inputs = torch.randn(2,1000).cuda()

a = (inputs @ model)[0]
b = inputs[0] @ model
print((a - b).abs().max())
# tensor(2.8610e-05, device='cuda:0')


from torch.profiler import profile, record_function, ProfilerActivity


with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    a = (inputs @ model)[0]
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                            aten::matmul         0.41%       9.000us        98.46%       2.175ms       2.175ms       0.000us         0.00%      21.000us      21.000us             1  
#                                                aten::mm        85.60%       1.891ms        98.05%       2.166ms       2.166ms      21.000us       100.00%      21.000us      21.000us             1  
# void gemmSN_NN_kernel<float, 256, 4, 2, 8, 2, 4, fal...         0.00%       0.000us         0.00%       0.000us       0.000us      21.000us       100.00%      21.000us      21.000us             1  
# cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...        11.18%     247.000us        11.18%     247.000us     247.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                        cudaLaunchKernel         1.27%      28.000us         1.27%      28.000us      28.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                            aten::select         1.00%      22.000us         1.09%      24.000us      24.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                        aten::as_strided         0.09%       2.000us         0.09%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                   cudaDeviceSynchronize         0.45%      10.000us         0.45%      10.000us      10.000us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 2.209ms
# Self CUDA time total: 21.000us


with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    b = inputs[0] @ model
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                            aten::matmul         0.17%      19.000us        97.55%      10.962ms      10.962ms       0.000us         0.00%      10.000us      10.000us             1  
#                                                aten::mm         0.61%      68.000us        97.20%      10.922ms      10.922ms      10.000us       100.00%      10.000us      10.000us             1  
# std::enable_if<!(false), void>::type internal::gemvx...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us       100.00%      10.000us      10.000us             1  
#                                            aten::select         0.18%      20.000us         0.21%      24.000us      24.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                        aten::as_strided         0.04%       4.000us         0.04%       4.000us       2.000us       0.000us         0.00%       0.000us       0.000us             2  
#                                         aten::unsqueeze         0.07%       8.000us         0.07%       8.000us       8.000us       0.000us         0.00%       0.000us       0.000us             1  
# cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...         0.04%       5.000us         0.04%       5.000us       5.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                        cudaLaunchKernel        96.55%      10.849ms        96.55%      10.849ms      10.849ms       0.000us         0.00%       0.000us       0.000us             1  
#                                          aten::squeeze_         0.08%       9.000us         0.12%      13.000us      13.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                       aten::as_strided_         0.04%       4.000us         0.04%       4.000us       4.000us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 11.237ms
# Self CUDA time total: 10.000us

As you can see, two different kernels are used: gemmSN_NN_kernel and internal::gemvx.
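
If you want to double-check that both kernels are individually fine, you can compare each float32 result against a float64 reference of the same row; both should sit well within the expected float32 error (a quick sketch reusing the tensors from above):

# Compare both float32 results against a float64 reference computation.
ref = (inputs.double() @ model.double())[0]
print((a.double() - ref).abs().max())  # expected to be on the order of 1e-5
print((b.double() - ref).abs().max())  # similarly small, but not necessarily identical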

This is again very useful. I did not know you could see these internal calls with a profiler, and this is exactly what I was missing.
Thanks a lot!