I’ve been stuck on this issue for some time and need some guidance.
My network contains an nn.ModuleDict where each key is a city and the value is a classifier for that city. Depending on the data in the batch, different classifiers are called and trained. The model works fine in a single-GPU setting. However, when I switch to multiple GPUs, training hangs at the end of the epoch.

My guess is that this is because, depending on the data present in each batch, different classifiers are being called on different GPUs (I have confirmed this is happening). Since the computation graph then differs across GPUs, the model is not able to synchronize properly. Is my assumption correct? If so, why does it hang only at the end of the epoch, and what can I do to fix this issue?
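For reference, here is a minimal sketch of what the model looks like (class names, dimensions, and city names are just placeholders, not my real code):

```python
import torch
import torch.nn as nn

class CityModel(nn.Module):
    def __init__(self, cities, in_dim=128, hidden_dim=64, num_classes=10):
        super().__init__()
        # Shared encoder used for every sample.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
        )
        # One classifier head per city, stored in a ModuleDict.
        self.classifiers = nn.ModuleDict({
            city: nn.Linear(hidden_dim, num_classes) for city in cities
        })

    def forward(self, x, city):
        # Only the classifier for this batch's city is used, so the
        # parameters of the other heads receive no gradient in that step.
        h = self.encoder(x)
        return self.classifiers[city](h)

# Different GPUs can receive batches from different cities, so each
# process may exercise a different head in the same iteration.
model = CityModel(cities=["london", "paris", "tokyo"])
out = model(torch.randn(8, 128), city="paris")
```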
Thanks