Reproducibility of CUDAExtension

Hi everyone!

I am fine-tuning the following model using the script provided. However, even though I follow the reproducibility guidelines, I obtain different losses across experiments with the same seed. I used PyTorch Lightning's seed_everything function and also tried seeding each random source manually, as follows:

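import torch
import os
import numpy as np
import random as r

# PYTHONHASHSEED is only read at interpreter start-up, so setting it here
# affects Python processes started afterwards, not the current one.
os.environ['PYTHONHASHSEED'] = str(1)
np.random.seed(1)                          # NumPy global RNG
r.seed(1)                                  # Python's built-in RNG
torch.manual_seed(0)                       # PyTorch RNG (also seeds the CUDA generators)
torch.cuda.manual_seed(0)                  # current CUDA device
torch.cuda.manual_seed_all(0)              # all CUDA devices
torch.backends.cudnn.benchmark = False     # no cuDNN autotuning (it may pick different algorithms per run)
torch.backends.cudnn.deterministic = True  # restrict cuDNN to deterministic algorithms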
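A side note on the snippet above, sketched in plain Python: PYTHONHASHSEED is read when the interpreter starts, so assigning it through os.environ inside an already-running script does not change hash() in that process; it only affects Python processes started afterwards. The helper name hash_in_fresh_interpreter is made up for this illustration.

```python
import os
import subprocess
import sys


def hash_in_fresh_interpreter(text: str, hash_seed: str) -> int:
    """Compute hash(text) in a new Python process started with PYTHONHASHSEED=hash_seed.

    Hypothetical helper: it shows that the variable takes effect at interpreter
    start-up, which is why setting it inside a running training script is too
    late for that script itself.
    """
    env = dict(os.environ, PYTHONHASHSEED=hash_seed)
    result = subprocess.run(
        [sys.executable, "-c", f"print(hash({text!r}))"],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return int(result.stdout.strip())


if __name__ == "__main__":
    # Two fresh interpreters started with the same seed agree on the string hash.
    print(hash_in_fresh_interpreter("shift", "1") == hash_in_fresh_interpreter("shift", "1"))  # True
```

In practice this means exporting the variable in the shell (or job launcher) before Python starts. It has no effect on GPU kernels.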

Finally, I tried freezing the whole model and training only the fully connected layer at the end; in this case the results are reproducible, which means that the source of randomness is inside the backbone. With torchvision’s 3D ResNet-18 the problem does not occur, so I am quite certain that it lies in the CUDAExtension the authors provided for the shift module. Is there a way to fix the seed for the CUDA module?

Thank you!


I quickly glanced at it: the CUDA kernel uses atomicAdd, which is non-deterministic.

Thanks for the answer. Is there any deterministic alternative to atomicAdd?

AFAIK, no. It is possible to rewrite the kernel to store the unaccumulated values and sum them later deterministically (e.g. with torch.sum, or maybe cuBLAS), but it would take some effort.
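A minimal pure-Python sketch of the store-then-reduce idea, assuming float32 arithmetic (emulated with struct) and using a shuffled list to stand in for the GPU thread scheduler. In phase 1 every "thread" i writes only its own slot buffer[i], so the order in which threads run no longer matters; phase 2 sums each output's contributions in a fixed order (ascending i). On the GPU, phase 2 would be something like the torch.sum mentioned above; the function names here are hypothetical.

```python
import random
import struct


def to_f32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def scatter_then_reduce(keys, values, num_out, thread_order):
    """Deterministic "index add" (out[keys[i]] += values[i]) done in two phases."""
    n = len(keys)
    # Phase 1: threads may run in any order, but thread i writes only buffer[i],
    # so there are no conflicting writes and no atomics are needed.
    buffer = [0.0] * n
    for i in thread_order:
        buffer[i] = to_f32(values[i])
    # Phase 2: reduce in a fixed order (ascending i), independent of scheduling.
    out = [0.0] * num_out
    for i in range(n):
        out[keys[i]] = to_f32(out[keys[i]] + buffer[i])
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    n = 1000
    keys = [rng.randrange(4) for _ in range(n)]
    values = [rng.uniform(-1e3, 1e3) for _ in range(n)]
    order_a = list(range(n))
    order_b = list(range(n))
    rng.shuffle(order_a)
    rng.shuffle(order_b)
    print(scatter_then_reduce(keys, values, 4, order_a)
          == scatter_then_reduce(keys, values, 4, order_b))  # True
```

The price is memory: the buffer holds one entry per contribution instead of one per output location.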

To expand a bit on Alex's answer:

  • In general, what you compute with atomicAdd is an “index add”; a sequential version would look like
      for i in range(...):
          foo[key[i]] += value[i]
    
    The non-determinism comes from the fact that the order in which the atomicAdd additions are performed is non-deterministic (in contrast to that for loop).
  • Now, an “obvious” way to make this deterministic (and also faster if the number of keys leads to lots of conflicts) is to sort keys and values by key and then have each thread process a segment.
    This can get involved; see the embedding bag backward in the PyTorch source code for an example.
    I once thought I might get crowdfunding to get deterministic kernels for PyTorch, but it didn’t work out. However, Kurt Mohler from Quansight implemented a mechanism to flag nondeterminism (now exposed as torch.use_deterministic_algorithms). I don’t know whether he will also write deterministic alternatives to the kernels.
  • The nondeterminism comes from floating-point addition being non-associative: (1e30 + 1) - 1e30 = 0, but (1e30 - 1e30) + 1 = 1, so the result depends on the order of the additions.
    One conceptually easy option to avoid that could be to quantize the computation to fixed precision (compute once to find the maximum modulus and then compute with that precision, or so). Integer addition is associative, so atomicAdd on integers is deterministic.
  • If you can live with “enhanced approximate determinism”, one option might be to do the atomic add computation in double and then round back to float. The larger precision of double means that the non-determinism might just be rounded away, or at least a large part of it.
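The points above can be illustrated with a small pure-Python simulation, assuming float32 arithmetic (emulated with struct) and a random permutation as a stand-in for the order in which the GPU happens to execute atomicAdd. atomic_index_add mimics the non-deterministic kernel, sorted_segment_index_add is the sort-by-key variant, and fixed_point_index_add is the quantization idea; all three names are hypothetical, and the 32 fractional bits are an arbitrary choice.

```python
import random
import struct


def to_f32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def atomic_index_add(keys, values, num_out, rng):
    """Like out[keys[i]] += values[i], but the updates land in a random order,
    mimicking atomicAdd under the CUDA scheduler. Each add is rounded to float32."""
    order = list(range(len(keys)))
    rng.shuffle(order)
    out = [0.0] * num_out
    for i in order:
        out[keys[i]] = to_f32(out[keys[i]] + to_f32(values[i]))
    return out


def sorted_segment_index_add(keys, values, num_out):
    """Sort by key (stable, so equal keys keep their index order), then sum each
    segment front to back. The summation order is fixed by the data alone."""
    perm = sorted(range(len(keys)), key=lambda i: keys[i])
    out = [0.0] * num_out
    for i in perm:
        out[keys[i]] = to_f32(out[keys[i]] + to_f32(values[i]))
    return out


def fixed_point_index_add(keys, values, num_out, rng, frac_bits=32):
    """Quantize to integers, add them in a random order, convert back.
    Integer addition is exact and associative, so the order does not matter."""
    max_abs = max((abs(v) for v in values), default=0.0) or 1.0
    scale = 2.0 ** frac_bits / max_abs  # the largest |value| maps to roughly 2**frac_bits
    order = list(range(len(keys)))
    rng.shuffle(order)
    acc = [0] * num_out  # would be 64-bit integer accumulators on the GPU
    for i in order:
        acc[keys[i]] += round(values[i] * scale)
    return [to_f32(a / scale) for a in acc]


if __name__ == "__main__":
    data_rng = random.Random(0)
    n = 2000
    keys = [data_rng.randrange(4) for _ in range(n)]
    values = [data_rng.uniform(-1e3, 1e3) for _ in range(n)]
    # Different simulated schedules usually give slightly different float32 sums.
    print(atomic_index_add(keys, values, 4, random.Random(1)))
    print(atomic_index_add(keys, values, 4, random.Random(2)))
    # Deterministic variants: repeated runs agree bit for bit.
    print(sorted_segment_index_add(keys, values, 4))
    print(fixed_point_index_add(keys, values, 4, random.Random(1))
          == fixed_point_index_add(keys, values, 4, random.Random(2)))  # True
```

In a real kernel the fixed-point variant needs a bound on the total (here at most n * 2**frac_bits) so that frac_bits can be chosen without overflowing the 64-bit accumulators; it trades some precision for determinism.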

Best regards

Thomas


Thanks, Alex, for your help, and thanks @tom for your detailed answer; I will explore your suggestions. I have never written any CUDA code, but I found the occurrences of atomicAdd in the code. For example:

atomicAdd(&shift_grad_buffer[0][C_idx][H_idx][W_idx], pixel_H_grad);

I tried your enhanced approximate determinism suggestion with the following:

shift_grad_buffer[0][C_idx][H_idx][W_idx] = (float) (shift_grad_buffer[0][C_idx][H_idx][W_idx] + pixel_H_grad);

However, it seems to behave the same as with atomicAdd. I used a C-style type cast and added both terms. It is quite possible that my solution is not correct; I will try to debug the CUDA code to see what happens. Thanks!

No, this isn’t how it works. The atomic add is needed to avoid the data race.
From a brief look, if you invoke the function with doubles, the computation should be done in doubles, so you could just cast to double, run the kernel, then cast back. Note that double-precision computation is much slower than single precision (aka float). Still, it should give you an idea of how much of the randomness in your outputs this removes.
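A sketch of the double-precision idea, using the same kind of simulation (float32 emulated with struct, a shuffled order standing in for the scheduler; the function name is hypothetical): inputs are float32, the accumulation runs in double, and only the final result is rounded back to float32. This reduces rather than removes the non-determinism, because two double sums that differ in their last bits can still round to different float32 values.

```python
import random
import struct


def to_f32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def index_add_via_double(keys, values, num_out, rng, accumulate_in_double=True):
    """Index add in a random (scheduler-like) order.

    accumulate_in_double=True: add in double, round to float32 once at the end.
    accumulate_in_double=False: round every partial sum to float32, like a float kernel.
    """
    order = list(range(len(keys)))
    rng.shuffle(order)
    out = [0.0] * num_out  # Python floats are doubles
    for i in order:
        s = out[keys[i]] + to_f32(values[i])
        out[keys[i]] = s if accumulate_in_double else to_f32(s)
    return [to_f32(x) for x in out]


if __name__ == "__main__":
    data_rng = random.Random(0)
    n = 5000
    keys = [data_rng.randrange(8) for _ in range(n)]
    values = [data_rng.uniform(-1.0, 1.0) for _ in range(n)]
    for use_double in (False, True):
        results = {tuple(index_add_via_double(keys, values, 8, random.Random(seed), use_double))
                   for seed in range(20)}
        # Number of distinct outcomes over 20 simulated schedules (fewer is better).
        print("double" if use_double else "float32", len(results))
```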

Best regards

Thomas

The problem arises when the write indices ([C_idx][H_idx][W_idx]) are not unique across the CUDA grid, i.e. when the algorithm includes some reduction. The non-determinism comes from the CUDA thread scheduler.

The obvious fix is to have unique write indices per kernel launch. For example, if the reduction is only across the batch dimension, you can process batch elements one by one (synchronizing between kernel launches); this is obviously slower. Or write to a big “buffer”, so that you don’t have many-to-one index mappings.
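A minimal sketch of the batch-by-batch fix, assuming the only many-to-one mapping comes from summing over the batch dimension: within one simulated launch each thread c writes only out[c] (unique write indices, so thread order is irrelevant), and the launches are accumulated in ascending batch order. has_many_to_one is a quick check for the non-unique-index situation; both function names are hypothetical.

```python
import random
import struct


def to_f32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def has_many_to_one(write_indices) -> bool:
    """True if two threads would write the same location (a reduction, hence atomics)."""
    return len(set(write_indices)) != len(write_indices)


def one_launch(contrib_b, rng):
    """One simulated kernel launch for a single batch element: threads run in a
    random order, but thread c writes only out[c], so the result is order-free."""
    out = [0.0] * len(contrib_b)
    threads = list(range(len(contrib_b)))
    rng.shuffle(threads)
    for c in threads:
        out[c] = to_f32(contrib_b[c])
    return out


def batch_reduce_one_by_one(contributions, rng):
    """grad[c] = sum over b of contributions[b][c], one launch per batch element,
    accumulated in ascending b (the synchronize-between-launches step)."""
    total = [0.0] * len(contributions[0])
    for contrib_b in contributions:
        partial = one_launch(contrib_b, rng)
        total = [to_f32(t + p) for t, p in zip(total, partial)]
    return total


if __name__ == "__main__":
    B, C = 16, 5
    data_rng = random.Random(0)
    contributions = [[data_rng.uniform(-1.0, 1.0) for _ in range(C)] for _ in range(B)]
    # A single launch over the whole batch would write each grad[c] from B threads:
    print(has_many_to_one([c for _ in range(B) for c in range(C)]))  # True
    print(batch_reduce_one_by_one(contributions, random.Random(1))
          == batch_reduce_one_by_one(contributions, random.Random(2)))  # True
```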

P.S. atomicAdd(double*) requires compute capability 6.0 or higher; for older GPUs there is a workaround that casts to (unsigned) long long and uses atomicCAS.
P.P.S. I wouldn’t suggest doing any of this without CUDA C++ experience, if you can avoid it.

I guessed it wouldn’t be so straightforward, @tom. I have used C++, though not for a long time, and I understand your warning, Alex: I will not touch the CUDA kernel code if I am not sure about what I am doing, and I don’t consider myself qualified to test your solution. However, as @tom suggested, I managed to cast both parameters to double in Python before calling the kernel, and back afterwards, to check how much the randomness is reduced. It is still not reproducible: during training the losses are quite similar, but there are occasional peak differences, smaller than before but still present. I think I will use this change only when the backbone has to be trained, and keep it frozen otherwise.

Thanks both for your help, I have learnt a lot about the C++/CUDA PyTorch API :smiley: