PyTorch Geometric (Temporal) with Lightning

Hi, I am pretty new to deep learning, let alone geometric deep learning. Nonetheless, I would like to follow best practices from the start, such as using Lightning with PyTorch. However, I am having trouble converting the temporal-graph-specific structure of the training loop to Lightning. So far, it is unclear to me how to iterate over the snapshots manually.

Furthermore, PyTorch Geometric Temporal seems to use a concept of temporal snapshots (a snapshot is not the same thing as a batch), and it assumes that every snapshot fits fully into memory.
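To make the snapshot idea concrete, here is a minimal sketch assuming a static graph whose node features and targets change over time. `Snapshot` and `ToyStaticGraphSignal` are hypothetical stand-ins loosely modeled on PyG Temporal's static-graph signal, not its actual classes. Iterating yields one whole graph per time step, which is why a snapshot is a time slice rather than a batch of independent samples, and why each snapshot has to fit into memory on its own.

```python
import torch
from typing import Iterator, NamedTuple


class Snapshot(NamedTuple):
    # Hypothetical stand-in for one PyG Temporal snapshot (same attribute names)
    x: torch.Tensor           # node features at time t: [num_nodes, num_features]
    edge_index: torch.Tensor  # graph connectivity: [2, num_edges]
    edge_attr: torch.Tensor   # edge weights: [num_edges]
    y: torch.Tensor           # node targets at time t: [num_nodes]


class ToyStaticGraphSignal:
    """Hypothetical simplified signal: one fixed graph, features/targets vary over time."""

    def __init__(self, edge_index, edge_weight, features, targets):
        if features.size(0) != targets.size(0):
            raise ValueError("features and targets need the same number of time steps")
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.features = features  # [num_steps, num_nodes, num_features]
        self.targets = targets    # [num_steps, num_nodes]

    def __len__(self) -> int:
        return self.features.size(0)

    def __iter__(self) -> Iterator[Snapshot]:
        # Each iteration yields the WHOLE graph at one time step;
        # the graph structure is shared, only features and targets change
        for t in range(len(self)):
            yield Snapshot(self.features[t], self.edge_index, self.edge_weight, self.targets[t])


num_steps, num_nodes, num_features = 3, 4, 4
signal = ToyStaticGraphSignal(
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),  # toy ring graph
    edge_weight=torch.ones(4),
    features=torch.randn(num_steps, num_nodes, num_features),
    targets=torch.randn(num_steps, num_nodes),
)
snapshots = list(signal)  # one Snapshot per time step
```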

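```python
import torch
from tqdm import tqdm
# RecurrentGCN and train_dataset as defined in the PyG Temporal chickenpox example
model = RecurrentGCN(node_features = 4) # chickenpox
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
for epoch in tqdm(range(200)):
    cost = 0
    # one snapshot = the whole graph at one time step
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    # average the MSE over all snapshots, then take one optimizer step per epoch
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
```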

The code under pytorch_geometric_temporal/ at commit a13ea7876525ed9ba7c48ec69408024346eaded3 of benedekrozemberczki/pytorch_geometric_temporal on GitHub seems to already hold the answer.