What kind of error do you get?
This should work:
class MyModel(nn.Module):
    """Two-stage model whose stages can optionally live on separate GPUs.

    When ``split_gpus`` is true, ``large_submodule1`` is moved to GPU 0 and
    ``large_submodule2`` to GPU 1, and ``forward`` copies the intermediate
    activation to GPU 1 between the two stages.
    """

    def __init__(self, split_gpus):
        """Build the two submodules and optionally place them on two GPUs.

        Args:
            split_gpus: If truthy, put the first submodule on CUDA device 0
                and the second on CUDA device 1.
        """
        # nn.Module's own __init__ has to run before any submodule is
        # assigned as an attribute, otherwise the assignment is rejected.
        super().__init__()
        # The `...` values are placeholders: substitute the real submodules.
        self.large_submodule1 = ...
        self.large_submodule2 = ...
        self.split_gpus = split_gpus
        if split_gpus:
            self.large_submodule1.cuda(0)
            # The second stage goes to device 1 (the original snippet moved
            # the first stage twice and never moved the second).
            self.large_submodule2.cuda(1)

    def forward(self, x):
        """Run ``x`` through both submodules in sequence.

        Args:
            x: Input passed to ``large_submodule1``.

        Returns:
            The output of ``large_submodule2``.
        """
        x = self.large_submodule1(x)
        # Use the flag stored on the instance; the constructor argument is
        # not in scope here.
        if self.split_gpus:
            x = x.cuda(1)  # P2P GPU transfer
        return self.large_submodule2(x)