In graph convolutional networks, is it possible to apply the same model at multiple hops of the graph, but move the model to a different device at each hop?

Hi,

I have the following code, in which I apply the same model multiple times to the input (once per hop), but to avoid running out of CUDA memory I have to move the model between devices. I am getting an error and I am wondering whether this is possible at all. Here is a simple script that reproduces the issue:

import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.data import NeighborSampler 

EPS = 1e-15

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]




train_loader = NeighborSampler(data.edge_index, sizes=[10, 10], batch_size=256,
                               shuffle=True, num_nodes=data.num_nodes)


class SAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # A single SAGEConv layer that is re-applied at every hop.
        self.convs = SAGEConv(in_channels, hidden_channels)

    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
            # Run the first hop on cuda:0 and every later hop on cuda:1,
            # moving the shared conv layer along with the data.
            device = torch.device('cuda:0' if i == 0 else 'cuda:1')
            self.convs = self.convs.to(device)
            x = x.to(device)
            edge_index = edge_index.to(device)

            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs((x, x_target), edge_index)
        return x


model = SAGE(data.num_node_features, hidden_channels=data.num_node_features)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
x = data.x


def train():
    model.train()

    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        #adjs = [adj.to(device) for adj in adjs]
        optimizer.zero_grad()

        out = model(x[n_id], adjs)
        loss = torch.mean(out)

        loss.backward()
        optimizer.step()

train()

When I try to do that, I get this error:

File "test.py", line 79, in <module>
    train()
  File "test.py", line 76, in train
    loss.backward()
  File "/opt/conda/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
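
For reference, here is the single-device version of the forward pass that I am trying to get away from (everything stays on cuda:0). This is only a sketch and assumes the whole two-hop computation fits on one GPU, which is exactly the memory limit I am running into:

    # Single-device variant of forward(): no device moves between hops,
    # the model and all tensors stay on cuda:0.
    # (The model itself would be moved to cuda:0 once, before training.)
    def forward(self, x, adjs):
        device = torch.device('cuda:0')
        x = x.to(device)
        for edge_index, _, size in adjs:
            edge_index = edge_index.to(device)
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs((x, x_target), edge_index)
        return x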

So, is it possible to train the model while moving it between devices, without getting this error? Thanks.