How to use Opacus with a temporal GNN (e.g. A3TGCN2)?

Hello everyone,

Can you help me train a temporal GNN with Opacus?
I am trying to set up differentially private training for a graph temporal neural network, using Opacus for the differential privacy and A3TGCN2 from the pytorch-geometric-temporal library for the temporal GNN part. In particular, I am trying to combine the Opacus guide with the A3TGCN2 training example.

To check whether the resulting model is Opacus-compliant, I ran it through
the Opacus ModuleValidator and got no errors, so I assume the model is compliant.
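
For reference, the validator check looks roughly like this (a minimal sketch, where model is the TemporalGNN instance from the setup below):

from opacus.validators import ModuleValidator

errors = ModuleValidator.validate(model, strict=False)
print(errors)  # prints an empty list, i.e. no incompatible layers were found
# ModuleValidator.fix(model) would swap out incompatible layers, but was not needed here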

Unfortunately, I still run into errors. Could you please help me understand where the errors come from and how I can fix them, and ideally provide a minimal working example?

Here are the relevant parts of my code:
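
For context, the model, optimizer, and loss follow the A3TGCN2 batching example from pytorch-geometric-temporal. The sketch below is my reconstruction of that setup; the hidden size, learning rate, and batch size are placeholder values, and train_loader yields (encoder_inputs, labels) batches as in that example:

import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import A3TGCN2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TemporalGNN(torch.nn.Module):
    def __init__(self, node_features, periods, batch_size):
        super().__init__()
        # Attention-based temporal graph convolution (batched variant)
        self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32,
                            periods=periods, batch_size=batch_size)
        # Map the hidden state to one prediction per future period
        self.linear = torch.nn.Linear(32, periods)

    def forward(self, x, edge_index):
        h = self.tgnn(x, edge_index)  # x: (batch, nodes, features, periods)
        h = F.relu(h)
        return self.linear(h)

model = TemporalGNN(node_features=2, periods=12, batch_size=32).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()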

Using the privacy engine:

from opacus import PrivacyEngine

secure_mode = False
privacy_engine = PrivacyEngine(secure_mode=secure_mode)

# Privacy engine hyper-parameters
MAX_GRAD_NORM = 1.5
DELTA = 1e-5
EPSILON = 50.0
EPOCHS = 20

priv_model, priv_optimizer, priv_train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    max_grad_norm=MAX_GRAD_NORM,
    target_delta=DELTA,
    target_epsilon=EPSILON,
    epochs=EPOCHS,
)

Attempting to train the model:

from tqdm import tqdm

priv_model.train()

# All snapshots share the same static graph, so take edge_index from the first one
snapshot = next(iter(train_set))
edge_index = snapshot.edge_index.to(device)

for epoch in tqdm(range(EPOCHS)):
    loss_list = []
    for encoder_inputs, labels in priv_train_loader:
        y_hat = priv_model(encoder_inputs, edge_index)
        loss = criterion(y_hat, labels)
        loss.backward()
        priv_optimizer.step()
        priv_optimizer.zero_grad()
        loss_list.append(loss.item())

The error message:

ValueError: Poisson sampling is not compatible with grad accumulation. You need to call optimizer.step() after every forward/backward pass or consider using BatchMemoryManager

Attempt with BatchMemoryManager:

from opacus.utils.batch_memory_manager import BatchMemoryManager

for epoch in tqdm(range(EPOCHS)):
    loss_list = []
    with BatchMemoryManager(
        data_loader=priv_train_loader,
        max_physical_batch_size=2,
        optimizer=optimizer,
    ) as new_data_loader:
        for encoder_inputs, labels in new_data_loader:
            y_hat = priv_model(encoder_inputs, edge_index)
            loss = criterion(y_hat, labels)
            loss.backward()
            priv_optimizer.step()
            priv_optimizer.zero_grad()
            loss_list.append(loss.item())

The second error message:

AttributeError: 'SGD' object has no attribute 'signal_skip_step'
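
Looking at it again, I suspect the second error comes from passing the original SGD optimizer to BatchMemoryManager instead of the DPOptimizer returned by make_private_with_epsilon; as far as I can tell, signal_skip_step is a method of Opacus' DPOptimizer, which plain SGD does not have. A minimal sketch of the call as I now think it should look, assuming that is the cause:

with BatchMemoryManager(
    data_loader=priv_train_loader,
    max_physical_batch_size=2,
    optimizer=priv_optimizer,  # the DPOptimizer from make_private_with_epsilon, not the plain SGD
) as new_data_loader:
    ...

Is that the actual cause of the second error, and is the rest of the setup correct?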