DDP error when the loss depends only on a forward hook's output

I’m training a model whose loss is a function of an intermediate output deep inside the model. Since I import the model from another package, I don’t want to change anything in that package, e.g., modify its forward method to return the intermediate tensor I need. That could require many changes to the package and does not scale.

So instead, I register a forward hook on the submodule that produces the intermediate output. The full code is lengthy, so here is a snippet that should be enough to show the idea:

```python
from torch.optim import Adam
from the_package import MODEL

model = MODEL()  # suppose I have created the model by calling that package
optimizer = Adam(model.parameters(), lr=1e-3)

# register a hook that stores the intermediate output
intermediate = None
def hook_function(module, input, output):
    global intermediate  # without this, the assignment only creates a local variable
    intermediate = output
my_submodule = locate_my_submodule(model)  # some function that locates the submodule
my_submodule.register_forward_hook(hook_function)

model_output = model(input_batch)  # call the model (not .forward) so hooks run; model_output is not used
loss = loss_function(intermediate, labels_or_targets)
loss.backward()
optimizer.step()
```
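A self-contained toy version of the same pattern makes the gradient flow visible. It is only a sketch: a hypothetical `nn.Sequential` stands in for `MODEL`, and its first `Linear` layer stands in for the hooked submodule. Because the loss is computed from the hooked output alone, the layer after that submodule never receives a gradient, even though it still runs in the forward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for MODEL: model[0] is the "hooked submodule",
# model[2] is a layer downstream of it.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

captured = {}  # a dict avoids needing `global`/`nonlocal` inside the hook

def hook_function(module, inputs, output):
    captured["intermediate"] = output

handle = model[0].register_forward_hook(hook_function)

x = torch.randn(3, 4)
target = torch.randn(3, 8)
model_output = model(x)  # full forward pass; the hook fires on model[0]
loss = nn.functional.mse_loss(captured["intermediate"], target)
loss.backward()
handle.remove()

# The hooked layer got a gradient; the downstream layer did not.
print(model[0].weight.grad is not None)
print(model[2].weight.grad is None)
```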

The code runs on a single GPU. But when I wrap the model with DDP (DistributedDataParallel) on 8x V100, it fails after the first iteration with the following error:

Note that I have set find_unused_parameters=True, following posts like here and here.

One workaround might be to build a “dummy” loss on the actual model_output and add it to total_loss with a weight of 0. But that can add unwanted overhead. Is there a better solution?
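A minimal sketch of the zero-weighted dummy loss under DDP, assuming a single-process CPU `gloo` group (`world_size=1`) so it runs without GPUs, and a hypothetical toy `nn.Sequential` in place of `MODEL`. The extra term `0.0 * model_output.sum()` gives every parameter reachable from `model_output` a (zero) gradient, so DDP is not left waiting for gradients that never arrive. The cost is that backward also runs through the layers after the hooked submodule.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "gloo" group backed by a file store; a real job would
# launch one process per GPU (e.g. with torchrun).
store_file = os.path.join(tempfile.mkdtemp(), "store")
dist.init_process_group("gloo", init_method=f"file://{store_file}", rank=0, world_size=1)

torch.manual_seed(0)
# Hypothetical stand-in for MODEL; model[0] is the hooked submodule.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
ddp_model = DDP(model, find_unused_parameters=True)  # CPU module: no device_ids
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

captured = {}

def save_intermediate(module, inputs, output):
    captured["intermediate"] = output

# Register the hook on the wrapped module; DDP's forward calls it, so the hook fires.
ddp_model.module[0].register_forward_hook(save_intermediate)

losses = []
for step in range(3):
    x, target = torch.randn(3, 4), torch.randn(3, 8)
    model_output = ddp_model(x)
    loss = nn.functional.mse_loss(captured["intermediate"], target)
    # Zero-weighted "dummy" term on the real output.
    total_loss = loss + 0.0 * model_output.sum()
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    losses.append(loss.item())

dist.destroy_process_group()
```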