Cannot execute loss.backward() for training a specific layer

Hello, my code works when I set requires_grad for other layers, but it breaks as soon as I target this specific layer. (It still works when I train on a single GPU.)

I am training Transformers with FullyShardedDataParallel (FSDP) using ShardingStrategy.HYBRID_SHARD (full shard within a node, replicated across nodes), limit_all_gathers=True, and use_orig_params=True.
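For context, the wrapping looks roughly like the sketch below (a minimal sketch only; model, MyTransformerBlock, and the auto-wrap policy are placeholders rather than my exact code):

import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Placeholder auto-wrap policy; MyTransformerBlock stands in for my real block class.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MyTransformerBlock},
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # full shard within a node, replicate across nodes
    auto_wrap_policy=wrap_policy,
    limit_all_gathers=True,
    use_orig_params=True,
    device_id=torch.cuda.current_device(),
)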

Currently, my loss function is as follows.

import torch.nn.functional as F

# Match student hidden states at the target layer to the (detached) teacher states.
teacher_hidden_states = teacher_output.hidden_states[target_layer_idx + 1].detach()
student_hidden_states = output.hidden_states[target_layer_idx + 1]
loss = F.mse_loss(student_hidden_states, teacher_hidden_states)

If I add an additional CE loss on the final prediction, backward() suddenly works.
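For reference, the working variant looks roughly like this (a sketch only; labels and output.logits are assumptions based on a typical causal-LM output, not necessarily my exact field names):

# Adding a CE term on the final logits on top of the distillation loss makes backward() run.
distill_loss = F.mse_loss(student_hidden_states, teacher_hidden_states)
ce_loss = F.cross_entropy(
    output.logits.view(-1, output.logits.size(-1)),
    labels.view(-1),
)
loss = distill_loss + ce_loss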

The error message I got is as follows.

[rank1]:     loss.backward()
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 682, in _pre_backward_hook
[rank1]:     _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1222, in _prefetch_handle
[rank1]:     _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 300, in _unshard
[rank1]:     handle.unshard()
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1312, in unshard
[rank1]:     self._use_unsharded_flat_param(padded_unsharded_flat_param)
[rank1]:   File "/root/anaconda3/envs/my_env/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1443, in _use_unsharded_flat_param
[rank1]:     self._use_unsharded_views(
[rank1]:   File "/root/anaconda3/envs/openvla-oft/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/root/anaconda3/envs/openvla-oft/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1925, in _use_unsharded_views
[rank1]:     tensor.data = view
[rank1]: AttributeError: 'NoneType' object has no attribute 'data'

It seems like there is some timing issue.

Has anyone encountered similar issues?

Thanks a lot.

It’s definitely not a timing issue. The error means that during the backward unshard, FSDP tried to set the .data attribute on what should be a parameter tensor but found None instead. Sharing your training loop might help us figure this out.
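In the meantime, a quick sanity check worth running (just a sketch of what I would look at; model is assumed to be your FSDP-wrapped module):

# Hypothetical diagnostic, not from the original post: list which parameters are
# trainable vs. frozen after FSDP wrapping (use_orig_params=True preserves the
# original parameter names), so you can confirm that only the layers feeding
# hidden_states[target_layer_idx + 1] are expected to receive gradients.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
print(f"trainable: {len(trainable)} params, frozen: {len(frozen)} params")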