What is the best way to run a CRF with DDP while staying compatible with JIT?

Hi, I am trying to implement a CRF model like the one in this tutorial. However, when I wrap the model with DDP, it raises the error “AttributeError: ‘DistributedDataParallel’ object has no attribute ‘neg_log_likelihood’.” One possible workaround is to move the CRF loss calculation out of the model, but that does not seem like a good approach because: (1) the CRF also contains the Viterbi decoding, which would have to be moved out of my model as well; (2) I would run into problems when converting my model to JIT.

Could someone help? Thank you in advance.

What happens is that DDP wraps the model, so the original model is available as ddp_model.module.
DDP's forward method then essentially does the “distributed admin” (getting the data to the various cards, collecting the results, and preparing for the results of backward to be handled).
Now you could call ddp_model.module.neg_log_likelihood, but that bypasses DDP's forward and would run on a single instance, defeating the purpose of using DDP.
I think the easiest way is probably to fold all entry points that should be distributed into the forward method, with a flag “what” that you check to dispatch to the various functions.
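A minimal sketch of that dispatch could look roughly like the following. The class name `Tagger`, its layer, and the method `decode` are made up for illustration, and the bodies of `neg_log_likelihood` and `decode` are placeholders (cross-entropy and argmax) standing in for the real CRF loss and Viterbi code from the tutorial. Only the shape of `forward` is the point here.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class Tagger(nn.Module):
    """Hypothetical stand-in for the CRF model; not the tutorial's code."""

    def __init__(self, feat_dim: int, num_tags: int):
        super().__init__()
        self.emit = nn.Linear(feat_dim, num_tags)

    def neg_log_likelihood(self, feats: torch.Tensor, tags: torch.Tensor) -> torch.Tensor:
        # Placeholder: token-level cross-entropy stands in for the real CRF loss.
        return F.cross_entropy(self.emit(feats), tags)

    def decode(self, feats: torch.Tensor) -> torch.Tensor:
        # Placeholder: per-token argmax stands in for Viterbi decoding.
        return self.emit(feats).argmax(dim=-1)

    def forward(
        self,
        feats: torch.Tensor,
        tags: Optional[torch.Tensor] = None,
        what: str = "decode",
    ) -> torch.Tensor:
        # Single entry point: the "what" flag selects the function to run,
        # so every call can go through the DDP wrapper's forward.
        if what == "loss":
            assert tags is not None, "tags are required for what='loss'"
            return self.neg_log_likelihood(feats, tags)
        elif what == "decode":
            return self.decode(feats)
        else:
            raise ValueError("unknown value for 'what': " + what)


model = Tagger(feat_dim=8, num_tags=3)
feats = torch.randn(5, 8)           # 5 tokens, 8 features each
tags = torch.randint(0, 3, (5,))    # one gold tag per token

loss = model(feats, tags, what="loss")
path = model(feats, what="decode")

# Once the process group is initialized, the same calls go through the wrapper:
# ddp_model = torch.nn.parallel.DistributedDataParallel(model)
# loss = ddp_model(feats, tags, what="loss")
```

Since neg_log_likelihood and the decoding stay methods of the module, nothing has to be moved out of the model. Whether the result scripts cleanly with JIT depends on the actual CRF code, so that part would need to be checked separately.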

Best regards