Integrating Opacus with a custom PyTorch model raises errors, but the model trains fine on its own

I am trying to train a Transformer model with Opacus, but I run into the following error, which seems related to the custom layers. It does not arise when I train the model without privacy.
In general I understand that a separate grad sampler sometimes needs to be registered when Opacus cannot handle certain layers (a sketch of what that registration looks like is below). But that does not seem to be the case here, because I do not get the "layer incompatible/not recognized" error.
I also know that Opacus ships grad samplers for the basic layers the architecture is composed of, i.e. nn.Linear, nn.LayerNorm, etc. But could it still be something Opacus can't handle? Or is it the way the architecture was implemented? What are some things to note when implementing custom models like this with Opacus? Looking at the traceback, I think it's related to the backward pass at loss.backward().
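For reference, this is roughly what registering a custom grad sampler looks like. The ScaleLayer and its gradient formula are hypothetical toy examples (not part of the AKT model), and the signature follows the Opacus 1.4 sources, where activations arrives as a list of the captured inputs:

from typing import Dict, List

import torch
import torch.nn as nn
from opacus.grad_sample import register_grad_sampler


class ScaleLayer(nn.Module):
    # Hypothetical toy layer that Opacus has no built-in sampler for.
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * self.scale


@register_grad_sampler(ScaleLayer)
def compute_scale_grad_sample(
    layer: ScaleLayer, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    # Per-sample gradient w.r.t. scale: elementwise product of the stored
    # input and the backpropagated gradient, summed over every dim except
    # batch and feature.
    a = activations[0]
    grad = (a * backprops).reshape(a.shape[0], -1, a.shape[-1]).sum(dim=1)
    return {layer.scale: grad}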

FULL TRACEBACK:

UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
Traceback (most recent call last):
  File "main.py", line 283, in <module>
    best_epoch = train_one_dataset(
  File "main.py", line 73, in train_one_dataset
    train_loss, train_accuracy, train_auc = train(
  File "/media/OS/code-exp/KT/knowledge-tracing-bert-transformers/context-aware-transformer-priv (grad-sampler)/run.py", line 96, in train
    loss.backward()
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 71, in __call__
    return self.hook(module, *args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/opacus/grad_sample/grad_sample_module.py", line 337, in capture_backprops_hook
    grad_samples = grad_sampler_fn(module, activations, backprops)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/opacus/grad_sample/functorch.py", line 58, in ft_compute_per_sample_gradient
    per_sample_grads = layer.ft_compute_sample_grad(
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
    return _flat_vmap(
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/apis.py", line 363, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1295, in grad_impl
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1256, in wrapper
    output = func(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/opacus/grad_sample/functorch.py", line 34, in compute_loss_stateless_model
    output = flayer(params, batched_activations)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/_functorch/make_functional.py", line 342, in forward
    return self.stateless_model(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miniconda3/envs/kt-transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() missing 4 required positional arguments: 'k', 'v', 'mask', and 'zero_pad'
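Reading the last frames, the functorch fallback in opacus/grad_sample/functorch.py appears to replay the layer's forward() with only the first captured input (flayer(params, batched_activations)). As a sanity check, here is a toy module of my own (not the AKT code) that I would expect to hit the same TypeError under Opacus 1.4:

import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule


class ToyMultiInputAttention(nn.Module):
    # Stand-in for the AKT attention: a trainable parameter that lives
    # directly on the module, plus a forward() with several positional inputs.
    def __init__(self, d: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))

    def forward(self, q, k, v, mask, zero_pad):
        return (q + k + v) * self.gamma


toy_model = GradSampleModule(ToyMultiInputAttention(8))
q, k, v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
mask, zero_pad = torch.ones(4, 8), torch.zeros(4, 8)
out = toy_model(q, k, v, mask, zero_pad)
# No registered sampler exists for ToyMultiInputAttention, so Opacus uses the
# functorch fallback, which replays forward() with a single input tensor:
out.sum().backward()  # TypeError: forward() missing 4 required positional arguments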

Architecture:
https://github.com/arghosh/AKT → akt.py

Changes I made:
I reproduced the repo code exactly as it is and added the following lines in main.py:
# imports needed for the Opacus pieces below
from opacus.accountants import RDPAccountant
from opacus.grad_sample import GradSampleModule
from opacus.optimizers import DPOptimizer

model = load_model(params)
# optimizer = torch.optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.999), eps=1e-8)
optimizer = torch.optim.SGD(model.parameters(), lr=params.lr, momentum=0.9)
print("\n")
# opacus_utils.register_grad_sampler(MonotonicMultiHeadAttention)(compute_monotonic_attention_grad_sample)
batch_size = 24
accountant = RDPAccountant()
dp_model = GradSampleModule(model)
dp_optimizer = DPOptimizer(
    optimizer=optimizer,
    noise_multiplier=1.0,  # same as the make_private arguments
    max_grad_norm=1.0,  # same as the make_private arguments
    expected_batch_size=batch_size,  # if you're averaging your gradients, you need to know the denominator
)
sample_rate = batch_size / len(train_q_data)
print(sample_rate)
dp_optimizer.attach_step_hook(
    accountant.get_optimizer_hook_fn(
        # important for privacy accounting; should be equal to batch_size / len(dataset)
        sample_rate=sample_rate
    )
)

# training
all_train_loss = {}
all_train_accuracy = {}
all_train_auc = {}
all_valid_loss = {}
all_valid_accuracy = {}
all_valid_auc = {}
best_valid_auc = 0

for idx in range(params.max_iter):
    # Train Model
    train_loss, train_accuracy, train_auc = train(
        dp_model, params, dp_optimizer, train_q_data, train_qa_data, train_pid, accountant, label='Train')
    # Validation step
    valid_loss, valid_accuracy, valid_auc = test(
        dp_model, params, dp_optimizer, valid_q_data, valid_qa_data, valid_pid, label='Valid')
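For comparison, the manual GradSampleModule/DPOptimizer/RDPAccountant pieces above are meant to mirror the standard PrivacyEngine.make_private path. A sketch of that equivalent, assuming the data were wrapped in a DataLoader (the loader below is a placeholder, since my actual code feeds raw arrays to train()):

from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Placeholder loader: only illustrative, and assumes the data are tensors.
train_loader = DataLoader(
    TensorDataset(train_q_data, train_qa_data, train_pid),
    batch_size=batch_size,
    shuffle=True,
)

privacy_engine = PrivacyEngine(accountant="rdp")
dp_model, dp_optimizer, dp_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.0,  # same values as the manual DPOptimizer above
    max_grad_norm=1.0,
)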

I am providing all the details needed to reproduce this:
Opacus version: 1.4.0
PyTorch: 2.2.1 + CUDA 11.8
pandas, tqdm: latest versions work fine

Could the multiple inputs to forward() in the custom layers, or multiple forward passes through them, be the issue?
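In case it helps, this is the check I've been using to see which submodules end up on the functorch fallback path; it assumes GradSampleModule.GRAD_SAMPLERS (the registry filled by register_grad_sampler) is accessible as in the 1.4 sources:

from opacus.grad_sample import GradSampleModule

# Run on the plain model, before wrapping. Lists submodules that own
# trainable parameters directly but have no registered grad sampler;
# in Opacus 1.4 these go through the functorch fallback, which replays
# forward() with a single input tensor.
for name, module in model.named_modules():
    own_params = any(p.requires_grad for p in module.parameters(recurse=False))
    if own_params and type(module) not in GradSampleModule.GRAD_SAMPLERS:
        print(name, type(module).__name__)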