Runtime error raised in DDP when using .detach() to skip gradient computation in some DP ranks

Description

I’m training a Transformer-based network with DDP on several nodes. Here is the relevant part of my MHA layer code:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        if self.layer_id in pp_fail:
            # Simulated failure: keep the forward value but cut gradient
            # flow through this layer's attention output on this rank.
            hidden_states = residual + hidden_states.detach()
        else:
            hidden_states = residual + hidden_states

This simulates the situation where, if a node fails, its neighbor takes over the failed node’s tasks but skips gradient computation in the MHA of those layers; the gradients of the corresponding parameters are then computed by the healthy nodes sitting in other data-parallel ranks. I apply DDP to accelerate training, which means that `pp_fail` varies across DP ranks.
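
For example (hypothetical values just for illustration; in my actual runs `pp_fail` is configured elsewhere), on a 4-rank job it might look like this:

    import torch.distributed as dist

    # Hypothetical illustration: rank 1 simulates a failed neighbor and skips
    # MHA gradient computation for layers 3-8; the other DP ranks see no failure.
    pp_fail = set(range(3, 9)) if dist.get_rank() == 1 else set()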

I started training with `find_unused_parameters=True`, but I ran into this error:

    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. Since `find_unused_parameters=True` is enabled, this likely means that not all `forward` outputs participate in computing loss. You can fix this by making sure all `forward` function outputs participate in calculating loss.
    [rank1]: If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
    [rank1]: Parameters which did not receive grad for rank 1: model.layers.8.input_layernorm.weight, model.layers.8.self_attn.v_proj.weight, model.layers.8.self_attn.k_proj.weight, model.layers.7.input_layernorm.weight, model.layers.7.self_attn.v_proj.weight, model.layers.7.self_attn.k_proj.weight, model.layers.6.input_layernorm.weight, model.layers.6.self_attn.v_proj.weight, model.layers.6.self_attn.k_proj.weight, model.layers.5.input_layernorm.weight, model.layers.5.self_attn.v_proj.weight, model.layers.5.self_attn.k_proj.weight, model.layers.4.input_layernorm.weight, model.layers.4.self_attn.v_proj.weight, model.layers.4.self_attn.k_proj.weight, model.layers.3.input_layernorm.weight, model.layers.3.self_attn.v_proj.weight, model.layers.3.self_attn.k_proj.weight
    [rank1]: Parameter indices which did not receive grad for rank 1: 29 30 35 38 39 44 47 48 53 56 57 62 65 66 71 74 75 80

I suppose the reason is that DDP tries to collect gradients for all parameters across all DP ranks, but finds that some parameters have no gradient because of `detach()`. However, I don’t know how to deal with this error. What should I do? Thanks in advance.

Version

python 3.10.12
torch 2.3.0+cu121
torchaudio 2.3.0+cu121
torchvision 0.18.0+cu121

I don’t know how exactly you would like to create the gradients for the unused parameters, but instead of detaching the output on a rank, maybe you could zero out the corresponding gradients on that particular rank before the backward pass all-reduces the gradients?
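
Something like this untested sketch, using your `pp_fail` and assuming `model` is the module wrapped inside DDP:

    import torch

    # Keep the forward pass untouched (no detach) and instead force the
    # gradients of the affected parameters to zero on this rank. A tensor
    # hook fires before DDP's reducer sees the gradient, so the all-reduce
    # still runs for every parameter and this rank just contributes zeros.
    # Register once before training; the hooks persist across iterations.
    for layer_id in pp_fail:
        layer = model.layers[layer_id]
        failed_params = list(layer.self_attn.parameters()) + \
                        list(layer.input_layernorm.parameters())
        for p in failed_params:
            p.register_hook(torch.zeros_like)  # grad -> zeros of same shape

Since the detach would be gone, every parameter would take part in autograd on every rank, so you also shouldn’t need `find_unused_parameters=True` anymore.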

Actually, I want to get the average gradient over the ranks where the computation graph runs normally, and scatter it to all ranks. For example, if rank 0 runs into `hidden_states = residual + hidden_states.detach()` while ranks 1, 2, and 3 run into `hidden_states = residual + hidden_states`, I want to compute the average gradient over ranks 1, 2, and 3 and scatter the result to all ranks.
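
Concretely, with 4 DP ranks, if rank 0 contributes a zeroed gradient while ranks 1, 2, and 3 contribute g1, g2, g3, DDP’s all-reduce produces (0 + g1 + g2 + g3) / 4, whereas the gradient I want is (g1 + g2 + g3) / 3; so the result would have to be multiplied by 4/3, i.e. world_size / num_healthy.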

I think zeroing out the gradients instead of detaching is an option, though the averaged gradient would then need rescaling, since I don’t want the average taken over all ranks. But I’d still like to know whether there is a way to achieve my goal with the original detach-based approach.
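
Roughly, assuming the zero-out hooks from above are in place, I imagine the rescaling after `loss.backward()` would look like this (my own untested sketch; it does one all-reduce per layer for clarity, which could be batched into a single tensor):

    import torch
    import torch.distributed as dist

    @torch.no_grad()
    def rescale_failed_layers(model, pp_fail, world_size):
        # After loss.backward(), DDP has averaged every gradient over
        # world_size, but only the healthy ranks contributed non-zero
        # values; multiply back by world_size / num_healthy per layer.
        device = next(model.parameters()).device
        for layer_id, layer in enumerate(model.layers):
            healthy = torch.tensor(
                0.0 if layer_id in pp_fail else 1.0, device=device)
            dist.all_reduce(healthy)  # default SUM -> number of healthy ranks
            if healthy.item() == 0:
                continue  # no rank computed this layer's gradient
            scale = world_size / healthy.item()
            for p in list(layer.self_attn.parameters()) + \
                     list(layer.input_layernorm.parameters()):
                if p.grad is not None:
                    p.grad.mul_(scale)

    # Usage per step (ddp_model is the DDP wrapper around model):
    #   loss.backward()
    #   rescale_failed_layers(ddp_model.module, pp_fail, dist.get_world_size())
    #   optimizer.step()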