I am wondering if in Graph Conv Nets, we can apply the same model multiple times in different hops of the graph, but move the model to different devices in different hops?


I have the following code, which I try to apply the model multiple times to the input, but to avoid the cuda memory issue, i have to move the model between devices. I am getting an error and I am wondering if it is something possible? Here is a simple code which is easy to reproduce:

import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
#from torch_cluster import random_walk
from sklearn.linear_model import LogisticRegression

import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.data import NeighborSampler 

EPS = 1e-15

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

train_loader = NeighborSampler(data.edge_index, sizes=[10, 10], batch_size=256,
                               shuffle=True, num_nodes=data.num_nodes)

class SAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        self.convs = SAGEConv(in_channels, hidden_channels)

    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
            if i ==0:
                device = torch.device('cuda:0')
                self.convs = self.convs.to(device)
                x = x.to(device)
                edge_index = edge_index.to(device)
                device = torch.device('cuda:1')
                self.convs = self.convs.to(device)
                x = x.to(device)
                edge_index = edge_index.to(device)
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs((x, x_target), edge_index)
        return x

model = SAGE(data.num_node_features, hidden_channels=data.num_node_features)
model = model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
x = data.x

def train():

    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        #adjs = [adj.to(device) for adj in adjs]

        out = model(x[n_id], adjs)
        loss = torch.mean(out)



When i try to do that, I get this error:

File "test.py", line 79, in <module>
  File "test.py", line 76, in train
  File "/opt/conda/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

So, is it possible to train the model while i move the model between devices without getting this error? Thanks.