I have a custom nn.Module which uses another nn.Module called 'rnn'. Before wrapping with nn.DataParallel I was able to reach it via model.rnn, but afterwards it raises AttributeError: 'DataParallel' object has no attribute 'rnn'
What is the way to reach model attributes inside DataParallel?
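A minimal reproduction, assuming a hypothetical toy model with an rnn submodule:

    import torch
    import torch.nn as nn

    class MyModel(nn.Module):          # hypothetical model for illustration
        def __init__(self):
            super().__init__()
            self.rnn = nn.LSTM(10, 20)

    model = nn.DataParallel(MyModel())
    model.rnn         # AttributeError: 'DataParallel' object has no attribute 'rnn'
    model.module.rnn  # the wrapped network itself is always reachable via .module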
I know this post is 3-4 years old, but if anyone runs into this problem hopefully this is of help.
If you want to copy a layer from one model to another, make sure you also wrap the model you're copying from in the custom DataParallel class. I made the mistake of not doing so, which resulted in a lot of wasted time…
import torch

# simple fix for DataParallel that allows access to the wrapped model's attributes
class MyDataParallel(torch.nn.DataParallel):
    def __getattr__(self, name):
        try:
            # DataParallel's own attributes (e.g. 'module', 'device_ids') resolve first
            return super().__getattr__(name)
        except AttributeError:
            # anything else is looked up on the wrapped model
            return getattr(self.module, name)

    # def __setattr__(self, name, value):
    #     try:
    #         return super().__setattr__(name, value)
    #     except AttributeError:
    #         return setattr(self.module, name, value)
def load_weights(base_model_name, model, epoch):
    """
    Loads previously trained weights into a model, given an epoch and the model itself.
    :param base_model_name: name of the base model in the training session
    :param model: the model to load weights into
    :param epoch: which epoch of training to load
    :return: the model with the weights loaded in
    """
    # 'args' is assumed to hold the script's parsed command-line arguments
    pretrained_dict = torch.load('{}/{}_{}_{}.pt'.format(args.model_path, base_model_name, epoch, args.lr))['state_dict']
    model_dict = model.state_dict()
    # keep only the entries whose keys also exist in the target model
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
model = r2plus1d50()
model = MyDataParallel(model, device_ids=[0, 1])

# copy the layer trained up to epoch 13 into 'model'
epoch = 13
head_model = r2plus1d50()
head_model = MyDataParallel(head_model, device_ids=[0, 1])
head_model = load_weights(base_model_name, head_model, epoch)

# 'head1' is the name of the attribute that we want to copy over
head = getattr(head_model, "head1")

# assign through .module so the layer lands on the wrapped model itself
setattr(model.module, "head1", head)
For some reason, I was not able to override the __setattr__() method on the custom DataParallel class, which is why the attribute is set through model.module.
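A likely explanation: nn.Module.__setattr__ never raises AttributeError for an unknown name, it simply stores the value on the wrapper, so the except branch in the commented-out code above is never reached. A sketch of a forwarding __setattr__ that checks the wrapped model explicitly instead (my own assumption, not battle-tested):

    class MyDataParallel(torch.nn.DataParallel):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

        def __setattr__(self, name, value):
            # Forward the assignment when the wrapped model already has
            # an attribute of this name; otherwise fall back to normal
            # nn.Module behavior. Going through self.__dict__ avoids
            # infinite recursion before 'module' has been registered.
            modules = self.__dict__.get('_modules', {})
            if 'module' in modules and hasattr(modules['module'], name):
                setattr(modules['module'], name, value)
            else:
                super().__setattr__(name, value)

With this in place, setattr(model, "head1", head) would land on the wrapped model directly.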
Hope this helps.