I am creating a wrapper Module over a base nn.Module:
import torch.nn as nn

class SiameseModel(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, img1, img2):
        output1 = self.base.forward_features(img1)
        output2 = self.base.forward_features(img2)
        return output1, output2
I parallelize this model during training and save it along with the base.
model = SiameseModel(base).to(device)
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
torch.save(model.state_dict(), os.path.join(save_path, "epoch_0.pt"))
torch.save(model.module.base.state_dict(), os.path.join(save_path, "base_epoch_0.pt"))
Now during inference, I intend to use only the base and its specific forward_features method, so I can't call base(input); I have to call base.forward_features(input).
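For context, on a single GPU the inference call I want looks roughly like this (batch shape and variable names are just illustrative; I'm assuming forward_features returns the embedding tensor directly):

with torch.no_grad():
    img = torch.randn(8, 3, 224, 224).to(device)
    embeddings = base.forward_features(img)  # not base(img), which would run forward()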
In an effort to parallelize the base during inference, I do something like this:
base = get_base(args.base, pretrained=False)  # instantiate the base class
base.load_state_dict(torch.load(os.path.join(save_path, "base_epoch_0.pt")))  # load the base saved in the last block
if torch.cuda.device_count() > 1:
    base = torch.nn.DataParallel(base)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base.to(device)
base.eval()
But because I have to use the forward_features method, I eventually end up doing

embeddings = base.module.forward_features(input)

and only 1 GPU gets used. My hunch is that calling .module.forward_features is what restricts it to a single GPU.
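As far as I understand the internals (I may be wrong here), DataParallel only scatters the batch and replicates the model inside its own forward, so going through .module bypasses all of that:

out = base(img)                          # DataParallel.forward: scatter the batch,
                                         # replicate the model, gather the outputs
emb = base.module.forward_features(img)  # calls the unwrapped model directly, so the
                                         # whole batch runs on one device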
I have tried some other things, like saving the parallelized base and loading that, but I still end up having to call .module.forward_features. I also tried more ad-hoc ideas, such as assigning base.forward = base.forward_features, but none of these make use of all the GPUs.
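For reference, the forward reassignment attempt looked roughly like this (a sketch, not my exact code; the checkpoint path is the one from the earlier block):

base = get_base(args.base, pretrained=False)
base.load_state_dict(torch.load(os.path.join(save_path, "base_epoch_0.pt")))
base.forward = base.forward_features  # point forward at forward_features on the instance
if torch.cuda.device_count() > 1:
    base = torch.nn.DataParallel(base)
base.to(device)
base.eval()

with torch.no_grad():
    embeddings = base(img)  # hoping DataParallel would scatter this call across GPUs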
What would you suggest in this case?