I have a question about using functorch with the Opacus `GradSampleModule` in the latest code on the Opacus main branch. I copied the code below and added my questions as comments in it.

```
import torch.nn as nn
from functorch import grad, make_functional, vmap
from opacus.layers.dp_rnn import RNNLinear


def prepare_layer(layer, batch_first=True):
    """
    Prepare a layer to compute grad samples using functorch.
    The grad samples are computed by redoing the forward and
    backward passes on the functional version of the module.

    Args:
        layer: the layer to prepare
        batch_first: whether the input is batch_first or not
    """
    if len(list(layer.buffers())) > 0:
        raise NotImplementedError(
            "This layer has buffers and is not supported by Opacus"
        )
    if type(layer) is nn.EmbeddingBag:
        raise NotImplementedError("Functorch does not support EmbeddingBag yet")
    flayer, _ = make_functional(layer)

    def compute_loss_stateless_model(params, activations, backprops):
        if batch_first or type(layer) is RNNLinear:
            batched_activations = activations.unsqueeze(0)  # HERE the activations are assumed to be a single tensor
            batched_backprops = backprops.unsqueeze(0)
        else:
            # If batch_first is False, the batch dimension is the second dimension
            batched_activations = activations.unsqueeze(1)
            batched_backprops = backprops.unsqueeze(1)

        output = flayer(params, batched_activations)
        loss = (output * batched_backprops).sum()
        return loss

    ft_compute_grad = grad(compute_loss_stateless_model)
    # Note that the vmap is done on the first dimension, regardless of batch_first
    # This is because the activations and backprops given by the GradSampleModule
    # are always batch_first=True
    layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))


def ft_compute_per_sample_gradient(layer, activations, backprops):
    """
    Compute the per-sample gradient of the layer.

    Args:
        layer: the layer on which to compute the gradient
        activations: the input to the layer
        backprops: the gradient of the loss w.r.t. outputs of the layer
    """
    parameters = list(layer.parameters(recurse=True))
    if not hasattr(layer, "ft_compute_sample_grad"):
        prepare_layer(layer)

    per_sample_grads = layer.ft_compute_sample_grad(
        parameters, activations[0], backprops  # HERE only the first element of stored activations is used. Why?
    )

    ret = {}
    for i_p, p in enumerate(parameters):
        ret[p] = per_sample_grads[i_p]
    return ret
```
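The snippet above only ever forwards `activations[0]`. Below is a minimal sketch of why that breaks a layer with more than one input. `ScaledLinear` and `store_inputs_hook` are hypothetical names made up for illustration: the hook is a simplified stand-in for the forward hook that stores the layer inputs, not code copied from Opacus.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledLinear(nn.Module):
    """Hypothetical layer whose forward takes two required positional inputs."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_out, d_in))
        self.bias = nn.Parameter(torch.zeros(d_out))

    def forward(self, x, scale):
        return scale * F.linear(x, self.weight, self.bias)


def store_inputs_hook(module, forward_input, _forward_output):
    # Simplified stand-in for the forward hook: keep every positional input.
    if not hasattr(module, "activations"):
        module.activations = []
    module.activations.append([t.detach() for t in forward_input])


layer = ScaledLinear(3, 2)
layer.register_forward_hook(store_inputs_hook)
layer(torch.randn(5, 3), torch.randn(5, 2))

activations = layer.activations.pop()
print(len(activations))  # 2: one tensor per positional input

# Forwarding only activations[0], as ft_compute_per_sample_gradient does,
# fails because the second required input is missing.
forward_failed = False
try:
    layer(activations[0])
except TypeError as err:
    forward_failed = True
    print("forward failed:", err)
```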

My question is: why does Opacus assume that a layer receives only a single input? Or am I misunderstanding what the code currently does? In `ft_compute_per_sample_gradient`, `activations[0]` is passed instead of the whole `activations` list, which the forward hook stored earlier with all of the inputs. As a result, a model containing a layer that requires multiple inputs simply fails on the forward pass. I use the following solution:

```
# In ft_compute_per_sample_gradient: pass all stored inputs, not just activations[0]
per_sample_grads = layer.ft_compute_sample_grad(
    parameters, activations, backprops
)

# In compute_loss_stateless_model: unsqueeze every input in the list,
# then pass the list unrolled to the functional layer
if batch_first or type(layer) is RNNLinear:
    batched_activations = [act.unsqueeze(0) for act in activations]
    batched_backprops = backprops.unsqueeze(0)
else:
    # If batch_first is False, the batch dimension is the second dimension
    batched_activations = [act.unsqueeze(1) for act in activations]
    batched_backprops = backprops.unsqueeze(1)
output = flayer(params, *batched_activations)
```
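A self-contained sketch of the change above, covering only the `batch_first=True` branch. It assumes PyTorch 2.x and swaps in `torch.func.functional_call` for functorch's `make_functional`, so it is not the Opacus code path itself; `ScaledLinear` and `make_per_sample_grad_fn` are hypothetical names. Because `vmap` treats a list as a pytree, `in_dims=0` for the activations argument maps over dimension 0 of every tensor in the list.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap


class ScaledLinear(nn.Module):
    """Hypothetical two-input layer, used only for illustration."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_out, d_in))
        self.bias = nn.Parameter(torch.zeros(d_out))

    def forward(self, x, scale):
        return scale * F.linear(x, self.weight, self.bias)


def make_per_sample_grad_fn(layer: nn.Module):
    """Multi-input variant of prepare_layer (batch_first=True case only)."""

    def compute_loss(params, activations, backprops):
        # Inside vmap each tensor has lost its batch dim; restore a batch of 1.
        batched_activations = [act.unsqueeze(0) for act in activations]
        batched_backprops = backprops.unsqueeze(0)
        # Unpack every stored input into the layer's forward call.
        output = functional_call(layer, params, tuple(batched_activations))
        return (output * batched_backprops).sum()

    # in_dims=0 for `activations` applies to every tensor in the list.
    return vmap(grad(compute_loss), in_dims=(None, 0, 0))


torch.manual_seed(0)
layer = ScaledLinear(3, 2)
params = {name: p.detach() for name, p in layer.named_parameters()}
x, scale = torch.randn(5, 3), torch.randn(5, 2)
backprops = torch.randn(5, 2)

per_sample = make_per_sample_grad_fn(layer)(params, [x, scale], backprops)
print(per_sample["weight"].shape)  # torch.Size([5, 2, 3])

# Sanity check: per-sample gradients should sum to the ordinary batch gradient.
(layer(x, scale) * backprops).sum().backward()
print(torch.allclose(per_sample["weight"].sum(0), layer.weight.grad, atol=1e-5))
```

Comparing the summed per-sample gradients against a plain `backward()` on the same batch is a cheap way to test a change like this on one's own multi-input layer.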

Is this correct for a model taking multiple sequential arguments?