What kind of model does torch.nn.DataParallel return?

Now, I have a model defined as follows:


import torch
import torch.nn as nn


class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        output1 = self.embedding_net(x1)
        output2 = self.embedding_net(x2)
        output3 = self.embedding_net(x3)
        return output1, output2, output3

    def get_embedding(self, x):
        return self.embedding_net(x)

# Wrap the model so that forward() is split across the available GPUs
model = TripletNet(embedding_net)
model = torch.nn.DataParallel(model)

When I now call model.get_embedding(x), the following error occurs:

AttributeError: 'DataParallel' object has no attribute 'get_embedding'

How can I call get_embedding on the wrapped model?

Try model.module.get_embedding(x).
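Here is a minimal sketch of the error and the workaround. The nn.Linear(8, 4) embedding network and the input shape are made-up stand-ins, since the real embedding_net is not shown in the question.

```python
import torch
import torch.nn as nn


class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        return (self.embedding_net(x1),
                self.embedding_net(x2),
                self.embedding_net(x3))

    def get_embedding(self, x):
        return self.embedding_net(x)


# Hypothetical stand-in for the real embedding network (assumption).
embedding_net = nn.Linear(8, 4)
model = nn.DataParallel(TripletNet(embedding_net))

x = torch.randn(2, 8)

# The wrapper does not expose custom methods of the wrapped model.
try:
    model.get_embedding(x)
    raised = False
except AttributeError:
    raised = True

# The original model is stored in the .module attribute.
embedding = model.module.get_embedding(x)
print(raised, tuple(embedding.shape))
```

Note that a call made through model.module goes straight to the original model, so it is not split across GPUs the way model(x1, x2, x3) is.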


What's the difference between model.module and model.modules?

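model.module is an attribute: a reference to the original model that is wrapped inside nn.DataParallel.

model.modules() is a method that nn.DataParallel inherits from its parent class, nn.Module. It returns an iterator over every module in the network, starting with the wrapper itself.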
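A small sketch of the difference, using a made-up nn.Sequential network purely for illustration:

```python
import torch.nn as nn

# Hypothetical small network, used only to illustrate the two names.
net = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
wrapped = nn.DataParallel(net)

# .module is an attribute: the exact object that was wrapped.
same_object = wrapped.module is net

# .modules() is a method: it yields the wrapper first, then every
# nested submodule in registration order.
names = [type(m).__name__ for m in wrapped.modules()]
print(same_object, names)
```

So .module is what you use to reach your own methods and attributes, while .modules() is for walking the whole module tree, for example to apply an initializer to each layer.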
Refer to this post for more info:
