PyTorch not working smoothly with DeepSpeed ZeRO Stage-3

I notice that when using ZeRO Stage-3 in DeepSpeed to train a model without any recomputation (activation checkpointing), the model's parameters cannot be released correctly. I have already released param.data the way ZeRO does, by setting param.data to torch.Tensor([1]). However, the memory consumption doesn't decrease, which means the original param.data storage still remains in memory.
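Here is a minimal sketch of what I am observing (my own reproduction, not the DeepSpeed code), assuming a single Linear layer on GPU:

```python
import torch

model = torch.nn.Linear(4096, 4096).cuda()
x = torch.randn(8, 4096, device="cuda")

out = model(x)        # forward pass: autograd saves tensors for backward
loss = out.sum()

before = torch.cuda.memory_allocated()
for p in model.parameters():
    # "release" the parameter the way ZeRO-3 does
    p.data = torch.ones(1, device=p.data.device, dtype=p.data.dtype)
after = torch.cuda.memory_allocated()

# Allocated memory barely changes: the weight storage appears to still be
# held alive by what autograd saved during the forward pass.
print(before, after)
```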

I think the problem occurs in the autograd module, which may keep a reference to param.data when there are corresponding intermediate results (a.k.a. intermediate activations) in GPU memory.

Could anyone tell me how I can remove this reference to the parameter so the memory can be freed?

@albanD Could you please help me out?

Hi,

In general, you should never use .data :slight_smile:
Could you give more details on what you're trying to do here? It is expected that the autograd engine saves (a lot of) things in the graph.

I am trying to save as much memory as possible.
The implementation actually comes from DeepSpeed: this line sets param.data to a torch.ones(1) tensor. DeepSpeed/partition_parameters.py at master · microsoft/DeepSpeed · GitHub

However, the parameter memory is not released correctly. I found out that if a parameter Tensor has corresponding intermediate results in memory, the parameter's memory won't be released.

The parameter data is only gathered when it is needed, like the gather in pre_sub_module_backward_function here: DeepSpeed/stage3.py at master · microsoft/DeepSpeed · GitHub

But if the parameter memory is not released correctly, this gathering won't consume any additional memory. In other words, when I reset param.data back to the full parameter Tensor, memory_allocated does not change.
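For context, this is roughly the pattern I mean (a conceptual sketch of my understanding of ZeRO-3, not the actual DeepSpeed code; all_gather_param is a hypothetical placeholder):

```python
import torch

def pre_sub_module_backward_function(module):
    # Re-materialise the full weights just before this submodule's backward.
    for p in module.parameters(recurse=False):
        p.data = all_gather_param(p)   # hypothetical: gather partitioned shards

def post_sub_module_backward_function(module):
    # Release them again afterwards by pointing param.data at a tiny tensor.
    for p in module.parameters(recurse=False):
        p.data = torch.ones(1, device=p.data.device, dtype=p.data.dtype)
```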

I wouldn't say that. You do remove the parameter successfully. The problem is that other parts of the code are also using this Tensor, so the Tensor cannot be freed. This is actually expected.

If you want to play tricks with what autograd saves for backward, you have hooks for that: Autograd mechanics — PyTorch 1.10.0 documentation. There is also a tutorial for that.
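For reference, here is a minimal sketch of those saved-tensor hooks (the torch.autograd.graph.saved_tensors_hooks context manager from PyTorch 1.10). In this example the hooks just offload saved tensors to CPU and bring them back for backward; a ZeRO-style scheme could instead re-gather partitioned weights in the unpack hook:

```python
import torch

def pack_hook(tensor):
    # Called when autograd saves a tensor for backward: offload it to CPU.
    return tensor.to("cpu")

def unpack_hook(packed):
    # Called when backward needs the tensor again: bring it back to GPU.
    return packed.to("cuda")

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda")

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = model(x).sum()   # saved tensors pass through pack_hook here

loss.backward()             # and come back through unpack_hook here
```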

Thanks, I will try it out later.