Check if model is wrapped in nn.DataParallel

Hi,

I wondered whether there is an efficient way to check if a model is wrapped in nn.DataParallel. I get a lot of errors because DataParallel keeps the original model inside its module attribute, and I wondered if there is a more natural way of checking for data parallelism than len(list(Model.named_children())) == 1.
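To make the problem concrete, here is a small sketch of what the wrapping looks like. The model (a single nn.Linear) and the variable names are just placeholders. It shows that the original model lives in the wrapper's module attribute, which is why it is the wrapper's only child and why the state-dict keys gain a "module." prefix. It also shows why the named_children() length check is fragile: a plain model that happens to have exactly one child passes it too. On a CPU-only machine, nn.DataParallel simply stores the model without replicating it, so this runs without a GPU.

```python
import torch.nn as nn

# Placeholder model; any nn.Module behaves the same way.
net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

# The wrapper keeps the original model as its `module` attribute...
print(wrapped.module is net)  # True

# ...so `module` is the wrapper's single child.
print([name for name, _ in wrapped.named_children()])  # ['module']

# State-dict keys of the wrapper therefore carry a "module." prefix.
print(list(net.state_dict().keys()))      # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

# A plain (unwrapped) model with one child also has len(children) == 1,
# so the length check alone cannot tell the two cases apart.
plain = nn.Sequential(nn.Linear(4, 2))
print(len(list(plain.named_children())))  # 1
```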


I am using something like this for getting the model’s state dict (for saving):

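try:
    # Wrapped model: nn.DataParallel stores the original model in .module
    model_state_dict = model.module.state_dict()
except AttributeError:
    # Plain model: there is no .module attribute
    model_state_dict = model.state_dict()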

It essentially checks whether the model has an attribute named module, which nn.DataParallel creates to hold the original model. I don't know whether a cleaner way has been introduced in more recent versions.

P.S.: you have written DataLoader in your post title; you might attract more readers / helpers by correcting that :sweat_smile:


I thought of doing something similar, but then I would have to repeat all the code from the try block inside the except block as well, which I consider bad practice in general. Still, thanks for your insight.
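One way to avoid duplicating code between the try and except branches is to unwrap the model once and write everything after that only once. The sketch below does the same attribute check as the try/except, using getattr with a default. The helper name unwrap is hypothetical, not a PyTorch API.

```python
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    """Return model.module if the attribute exists, otherwise model itself.

    Hypothetical helper: same check as the try/except, but done once,
    so the code that follows is not duplicated in two branches.
    """
    return getattr(model, "module", model)


net = nn.Linear(4, 2)
for model in (net, nn.DataParallel(net)):
    # Everything from here on is written only once.
    model_state_dict = unwrap(model).state_dict()
    print(sorted(model_state_dict.keys()))  # ['bias', 'weight'] both times
```

One caveat of any attribute-based check: a plain model that defines its own submodule named module would be "unwrapped" by mistake, which an explicit type check avoids.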


isinstance(m, nn.DataParallel)
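A minimal sketch building on the isinstance check: the helper names is_parallel and unwrap are hypothetical, and including DistributedDataParallel is an extension beyond the question (it also keeps the original model in .module). Because it checks the wrapper type rather than the attribute, a plain model with a child that happens to be named module is not misdetected.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Wrapper types that keep the original model in `.module`.
_WRAPPERS = (nn.DataParallel, DistributedDataParallel)


def is_parallel(model: nn.Module) -> bool:
    """True if `model` is a DataParallel / DistributedDataParallel wrapper."""
    return isinstance(model, _WRAPPERS)


def unwrap(model: nn.Module) -> nn.Module:
    """Return the underlying model whether or not it is wrapped."""
    return model.module if is_parallel(model) else model


net = nn.Linear(4, 2)
dp = nn.DataParallel(net)
print(is_parallel(net), is_parallel(dp))  # False True
print(unwrap(dp) is net)                  # True
```

With this, saving becomes a single line, e.g. torch.save(unwrap(model).state_dict(), path), with no try/except.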


This is ingenious, thank you!