I have an NLP model and I need data parallelism because my batch size is large, so I wrapped the model in `nn.DataParallel`. I also need the model's attributes elsewhere during training, validation, logging, etc. Here's how I'm doing it:
```python
class Trainer():
    def __init__(self, model, config):
        self.model = model
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(model)
        self.txt_property = self.model.txt_property
        # many more properties

    def do_something(self):
        param = self.model.param
```
This raises:

```
AttributeError: 'DataParallel' object has no attribute 'txt_property'
```
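For context, I know that `DataParallel` keeps the original model under its `.module` attribute, so I can reach the attributes that way. A minimal sketch (with a made-up `MyModel` just to illustrate):

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.txt_property = "hello"  # custom attribute, as in my real model
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = MyModel()
wrapped = nn.DataParallel(model)

# The wrapped model itself does not expose custom attributes,
# but the original module is available as .module:
print(wrapped.module.txt_property)
```

But sprinkling `self.model.module.xxx` everywhere (and branching on whether the model is wrapped) feels clumsy.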
What is the best practice for encapsulating a model with `nn.DataParallel` while still being able to access its attributes cleanly?