Native PyTorch activation checkpointing gives different results than training without it

I am using activation checkpointing to trade compute for memory. The idea is to run the forward pass under torch.no_grad() and to recompute it just before backpropagation, so that the activations needed for the gradient computation are rematerialized instead of stored.
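To make the comparison concrete, here is a minimal sketch of what I am doing, using torch.utils.checkpoint.checkpoint on a toy module (the module and shapes are just placeholders, not my real model). With a fixed seed and no stochastic layers, I would expect the checkpointed and non-checkpointed runs to produce identical gradients:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    """Placeholder block standing in for one checkpointed segment."""
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))

def run(use_ckpt: bool):
    # Re-seed so both runs build identical weights and inputs.
    torch.manual_seed(0)
    block = Block()
    x = torch.randn(4, 8, requires_grad=True)
    if use_ckpt:
        # Forward runs without saving intermediate activations;
        # they are recomputed during backward.
        y = checkpoint(block, x, use_reentrant=False)
    else:
        y = block(x)
    loss = y.sum()
    loss.backward()
    return loss.item(), block.lin.weight.grad.clone()

loss_plain, grad_plain = run(use_ckpt=False)
loss_ckpt, grad_ckpt = run(use_ckpt=True)
print(loss_plain == loss_ckpt)
print(torch.allclose(grad_plain, grad_ckpt))
```

In this deterministic toy case the two runs match; my question is why my full training run does not.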

I was therefore expecting to get the same loss and performance values as a training run without checkpointing, given that my script is deterministic.

I also tried FairScale's activation checkpointing wrapper, but I could not reproduce the same values with it either.