Why is torch.mm non-deterministic?

I get the following error (it has been referenced many times on the internet, but I am putting it here anyway):

Traceback (most recent call last):
  File "main.py", line 79, in <module>
  File "main.py", line 69, in main
    distillation = Distillation(train_loader, train_dataset, model_wt_path, config_dict, args.seed) # passing the mean and stddev of the dataset for the images, this is more specific to images when done in this context
  File "/x0/megh98/projects/ddist/distillation.py", line 44, in __init__
    train_acc = self.run_validation(self.train_loader)
  File "/x0/megh98/projects/ddist/distillation.py", line 66, in run_validation
    logits = self.net(img)
  File "/x0/megh98/anaconda3/envs/ddist/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/x0/megh98/projects/ddist/models/networks.py", line 67, in forward
    out = self.classifier(out)
  File "/x0/megh98/anaconda3/envs/ddist/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/x0/megh98/anaconda3/envs/ddist/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility

This happens when I set torch.use_deterministic_algorithms(True), which is no surprise. The docs for torch.use_deterministic_algorithms (PyTorch 1.12 documentation) mention that torch.mm is a non-deterministic operation. So I have the following two questions:

  1. Why would a simple matrix multiplication such as torch.mm be a non-deterministic operation?
  2. From the error traceback, the non-determinism seems to come from the F.linear function. After looking at the PyTorch GitHub, my understanding is that F.linear uses torch.matmul, which in turn uses torch.mm, and that this is what leads to the non-determinism. Am I right in my understanding here?

The reasoning is given in both the error message and the documentation, and both point to this link: cuBLAS :: CUDA Toolkit Documentation
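For reference, the workaround named in the error message could be applied roughly as below. This is a minimal sketch, not the original poster's code: the tensor shapes and the guarded import (so the snippet still runs on a machine without PyTorch) are assumptions added for illustration. The environment variable value is one of the two suggested in the error message, and it has to be set before the first cuBLAS call, so setting it before importing torch is the safest place.

```python
import os

# Set this before PyTorch touches CUDA; ":16:8" is the other value
# suggested by the error message.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

try:
    import torch
except ImportError:  # assumption: lets the sketch run without PyTorch installed
    torch = None

if torch is not None:
    torch.manual_seed(0)
    # Raise an error for ops with no deterministic implementation.
    torch.use_deterministic_algorithms(True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(64, 32, device=device)  # illustrative shapes
    b = torch.randn(32, 16, device=device)

    # With the workspace config set, torch.mm on CUDA no longer raises
    # the RuntimeError shown in the traceback.
    out1 = torch.mm(a, b)
    out2 = torch.mm(a, b)
    print("same result on both calls:", torch.equal(out1, out2))
```

The same effect can be had by exporting CUBLAS_WORKSPACE_CONFIG in the shell before launching main.py, which avoids any dependence on import order.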

Thanks @smth. Yes, I did read the explanation in the link earlier, but it seems to be a general one that talks about CUDA streams and so on. I wanted to know what specifically makes the torch.mm function non-deterministic. Is it related to the explanation given here (cuBLAS :: CUDA Toolkit Documentation)? Is it something to do with the parallelism of the matrix multiplies?

Yes, the non-determinism comes from torch.mm calling cuBLAS, and cuBLAS being non-deterministic in these cases.
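As a general illustration of why the order of parallel work can matter at all: floating-point addition is not associative, so a dot product accumulated in a different order can give a different result. The sketch below is plain Python and does not model cuBLAS internals; the `dot` helper and the input values are hypothetical, chosen so the effect is visible in float64. Each output element of a matrix multiply is such a dot product, so if a library picks a different internal implementation between runs, the summation order can change and the bits of the result with it.

```python
def dot(xs, ys, order):
    """Dot product that accumulates the terms in the given index order."""
    total = 0.0
    for i in order:
        total += xs[i] * ys[i]
    return total


row = [1e16, 1.0, -1e16]
col = [1.0, 1.0, 1.0]

# (1e16 + 1.0) - 1e16: the 1.0 is lost when added to 1e16.
in_order = dot(row, col, [0, 1, 2])

# (1e16 - 1e16) + 1.0: the large terms cancel first, so the 1.0 survives.
reordered = dot(row, col, [0, 2, 1])

print(in_order, reordered)  # 0.0 1.0
```

Both results are valid floating-point evaluations of the same mathematical sum; they differ only in the order of accumulation.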