Hi, I am pretty new to deep learning, let alone geometric deep learning. Nonetheless, I would like to adopt some best practices from the beginning, such as using PyTorch Lightning. However, I am having trouble converting the temporal-graph-specific structure of the training loop to Lightning. So far, it is unclear to me how to manually iterate over the snapshots.
Furthermore, PyTorch Geometric Temporal seems to use a concept of temporal snapshots (which are not the same thing as batches), and it assumes that every snapshot fits fully into memory.
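```python
import torch
from tqdm import tqdm

# RecurrentGCN and train_dataset are assumed to be defined elsewhere,
# e.g. as in the PyTorch Geometric Temporal chickenpox example.
model = RecurrentGCN(node_features=4)  # chickenpox model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in tqdm(range(200)):
    cost = 0
    # Accumulate the loss over all snapshots, then take one optimizer step per epoch.
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # Flatten both so a [N, 1] prediction and a [N] target are compared
        # element-wise instead of being broadcast to [N, N].
        cost = cost + torch.mean((y_hat.reshape(-1) - snapshot.y.reshape(-1)) ** 2)
    cost = cost / (time + 1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
```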