Prevent duplicated copies of tensors to CPU

Hello. This is about extending save_on_cpu (torch.autograd.graph — PyTorch 2.0 documentation).
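
(For context, the stock behavior I'm starting from is just wrapping the forward pass in the context manager, roughly:)

import torch

# all tensors saved for backward get offloaded to CPU during the forward pass
with torch.autograd.graph.save_on_cpu():
    ...  # forward pass goes here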

In my case, I found that the same tensor got saved to CPU twice due to some branches in the forward path, and I'd like to avoid that. Hence, I did the following in the pack hook, for an incoming tensor x:

global_tensor_on_cpu = {}

def pack_hook(x):  # runs for each tensor saved for backward
    key_x = key_func(x)
    if key_x not in global_tensor_on_cpu:
        global_tensor_on_cpu[key_x] = x.cpu()
    return key_x
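
For reference, here is roughly how the whole thing is wired together with torch.autograd.graph.saved_tensors_hooks (a simplified, self-contained sketch; pack_hook, unpack_hook, and key_func are my own placeholder names, and here I also remember the original device so the unpack side can move the copy back):

import torch

global_tensor_on_cpu = {}

def key_func(x):
    # placeholder key; see the follow-up question below
    return (x.data_ptr(), tuple(x.shape), x.dtype)

def pack_hook(x):
    key_x = key_func(x)
    if key_x not in global_tensor_on_cpu:
        # keep one CPU copy per key, plus the original device for unpacking
        global_tensor_on_cpu[key_x] = (x.device, x.cpu())
    return key_x

def unpack_hook(key_x):
    device, cpu_copy = global_tensor_on_cpu[key_x]
    return cpu_copy.to(device)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    a = torch.randn(4, 4, requires_grad=True)
    out = (a * a).sum()  # forward under the hooks; `a` is saved for backward
out.backward()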

But I found that once I start applying this, the model accuracy drops for some reason. I wonder if there can be any side effects of doing this. Also, as a follow-up question, is there any good “key_func” to get a unique id for a tensor? I’m using a mix of the shape, data_ptr, and distribution info, but it seems a bit expensive to compute.
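
For concreteness, my current key_func looks roughly like this (the mean/std terms are what I mean by “distribution info”, and they are the expensive part; the exact statistics here are just an example):

def key_func(x):
    # shape + data_ptr + a couple of distribution statistics of the tensor
    return (
        x.data_ptr(),
        tuple(x.shape),
        x.dtype,
        round(float(x.float().mean()), 6),  # expensive: reads the whole tensor
        round(float(x.float().std()), 6),   # expensive: reads the whole tensor
    )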