Now, I have a model defined as follows:
import torch
import torch.nn as nn

class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        # Run the same embedding network on all three inputs
        output1 = self.embedding_net(x1)
        output2 = self.embedding_net(x2)
        output3 = self.embedding_net(x3)
        return output1, output2, output3

    def get_embedding(self, x):
        return self.embedding_net(x)

model = TripletNet(embedding_net)
model = torch.nn.DataParallel(model)
When I call the get_embedding function with model.get_embedding(x), I get the following error:
AttributeError: 'DataParallel' object has no attribute 'get_embedding'
How can I call get_embedding on the model after it has been wrapped in DataParallel?
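One approach that usually works, sketched here under some assumptions: nn.DataParallel stores the wrapped network in its `module` attribute, so custom methods can be reached through `model.module`. The `nn.Linear(8, 4)` embedding network and the input shape below are hypothetical stand-ins, not taken from the post.

```python
import torch
import torch.nn as nn


class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        return (self.embedding_net(x1),
                self.embedding_net(x2),
                self.embedding_net(x3))

    def get_embedding(self, x):
        return self.embedding_net(x)


# Hypothetical stand-in for the real embedding network (assumption).
embedding_net = nn.Linear(8, 4)

model = nn.DataParallel(TripletNet(embedding_net))

# Put the input on whatever device the wrapped parameters live on,
# so the sketch runs with or without a GPU.
device = next(model.module.parameters()).device
x = torch.randn(2, 8, device=device)

# DataParallel does not forward custom methods, but the original
# network is available as `model.module`.
embedding = model.module.get_embedding(x)
print(embedding.shape)
```

Note that calling `model.module.get_embedding(x)` bypasses DataParallel's scatter/gather, so the call runs on a single device. If the embedding step should also be split across GPUs, one option is to wrap `embedding_net` itself in DataParallel instead of the whole TripletNet.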