Having a hard time data-parallelizing a submodule

I am creating a wrapper module around a base nn.Module:

import torch.nn as nn

class SiameseModel(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base = base  # shared backbone for both branches

    def forward(self, img1, img2):
        output1 = self.base.forward_features(img1)
        output2 = self.base.forward_features(img2)
        return output1, output2
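To experiment without the real backbone, here is a self-contained version of the wrapper above. ToyBase is a hypothetical stand-in for whatever get_base returns; the only assumption is that it exposes forward_features and returns one embedding per image.

```python
import torch
import torch.nn as nn


class ToyBase(nn.Module):
    """Hypothetical stand-in for the real base; only forward_features matters here."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward_features(self, x):
        # (N, 3, H, W) -> (N, 8) embedding
        return self.pool(self.conv(x)).flatten(1)


class SiameseModel(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, img1, img2):
        # Both branches share the same base weights.
        return self.base.forward_features(img1), self.base.forward_features(img2)


model = SiameseModel(ToyBase())
out1, out2 = model(torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32))
print(out1.shape, out2.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```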

During training I wrap this model in DataParallel and save both the full model and the base on its own:

model = SiameseModel(base).to(device)
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

torch.save(model.state_dict(), os.path.join(save_path, "epoch_0.pt"))
torch.save(model.module.base.state_dict(), os.path.join(save_path, "base_epoch_0.pt"))
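Saving the base separately is useful because DataParallel registers the wrapped model as a submodule named module, so every key in model.state_dict() gains a "module." prefix. A minimal sketch, assuming a hypothetical nn.Linear as the base and a simplified SiameseModel that calls the base directly, showing the two sets of keys:

```python
import torch
import torch.nn as nn


class SiameseModel(nn.Module):
    """Simplified wrapper: calls the base directly (nn.Linear has no forward_features)."""

    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, img1, img2):
        return self.base(img1), self.base(img2)


# Hypothetical tiny base, only used to inspect parameter names.
base = nn.Linear(4, 2)
model = nn.DataParallel(SiameseModel(base))

# The full checkpoint carries both the DataParallel ("module.") and the
# wrapper ("base.") prefixes ...
print(sorted(model.state_dict().keys()))
# ['module.base.bias', 'module.base.weight']

# ... while the base-only checkpoint has plain keys.
print(sorted(model.module.base.state_dict().keys()))
# ['bias', 'weight']

# So it loads into a freshly built base without any key renaming.
fresh = nn.Linear(4, 2)
fresh.load_state_dict(model.module.base.state_dict())
```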

At inference time I only need the base and its forward_features method, so I can't simply call base(input); I have to call base.forward_features(input).

To parallelize the base during inference, I do something like this:

base = get_base(args.base, pretrained=False) #Instantiating the base class 
base.load_state_dict(torch.load(save_path)) #loading saved base from the last block
if torch.cuda.device_count() > 1:
    base = torch.nn.DataParallel(base)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base.to(device)
base.eval()

But because I have to call forward_features, I end up doing

embeddings = base.module.forward_features(input)

and only one GPU gets used. My hunch is that going through .module.forward_features is what restricts it to a single GPU.
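The hunch is consistent with how DataParallel works: scattering the batch, replicating the module and gathering the outputs all happen inside DataParallel's own forward, so base.module.forward_features(input) is an ordinary call on the original module on one device. A minimal sketch of routing forward_features through a regular forward method so that DataParallel can split the batch; FeatureExtractor and ToyBase are hypothetical names, and the base is assumed to expose forward_features as in the post.

```python
import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Hypothetical wrapper: forward() delegates to base.forward_features()."""

    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        # Defined on the class, so each DataParallel replica binds it to itself
        # and uses the parameters copied to its own GPU.
        return self.base.forward_features(x)


class ToyBase(nn.Module):
    """Hypothetical stand-in for the real base."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward_features(self, x):
        return self.fc(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extractor = FeatureExtractor(ToyBase()).to(device)
if torch.cuda.device_count() > 1:
    extractor = nn.DataParallel(extractor)  # splits the batch along dim 0
extractor.eval()

with torch.no_grad():
    embeddings = extractor(torch.randn(8, 16, device=device))
print(embeddings.shape)  # torch.Size([8, 4])
```

Because the saved weights would be loaded into the base before it is wrapped, the existing base-only checkpoint can be reused unchanged.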

I have tried other things, such as saving the parallelized base and loading it, but I still end up having to call .module.forward_features. I also tried workarounds like setting base.forward = base.forward_features, but none of them make use of all the GPUs.

What would you suggest in this case?

If I reassign the forward method before wrapping the base in DataParallel:

base.forward = base.forward_features
if torch.cuda.device_count() > 1:
    logger.info(f"Using {torch.cuda.device_count()} GPUs!")
    base = torch.nn.DataParallel(base)
embeddings = base(input)

I get a RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)

The same thing happens if I reassign it after wrapping:

if torch.cuda.device_count() > 1:
    logger.info(f"Using {torch.cuda.device_count()} GPUs!")
    base = torch.nn.DataParallel(base)
base.module.forward = base.module.forward_features
embeddings = base(input)
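Both errors are consistent with how DataParallel builds its replicas. As far as I can tell from the PyTorch source (the private Module._replicate_for_data_parallel), each replica is a new object whose __dict__ is a shallow copy of the original's. Assigning base.forward = base.forward_features (or base.module.forward = ...) stores a bound method in that __dict__, so every replica's forward still calls forward_features on the original module, whose weights sit on cuda:0, while the replica's slice of the input is on cuda:1. The sketch below mimics that shallow copy in plain Python, without the private API; ToyBase is hypothetical.

```python
import torch.nn as nn


class ToyBase(nn.Module):
    """Hypothetical base with a forward_features method."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward_features(self, x):
        return self.fc(x)


original = ToyBase()
original.forward = original.forward_features  # instance attribute: a bound method

# Mimic a replica: a new object of the same type with a shallow copy of __dict__.
replica = ToyBase.__new__(ToyBase)
replica.__dict__ = original.__dict__.copy()

# The copied bound method still calls forward_features on the ORIGINAL module,
# so a replica on cuda:1 would end up using weights that live on cuda:0.
print(replica.forward.__self__ is original)  # True
# Methods defined on the class are looked up fresh and bind to the replica.
print(replica.forward_features.__self__ is replica)  # True
```

A method defined on the class, such as a small wrapper whose forward calls self.base.forward_features, avoids this because each replica binds it to itself.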