Functorch activations with multiple inputs

I have a question about the use of functorch with Opacus's GradSampleModule in the latest code on the main branch of opacus. I have copied the relevant code below and commented my questions inline.

def prepare_layer(layer, batch_first=True):
    """
    Prepare a layer to compute grad samples using functorch.
    The grad samples are computed by redoing the forward and
    backward passes on the functional version of the module.

    Args:
        layer: the layer to prepare
        batch_first: whether the input is batch_first or not
    """
    if len(list(layer.buffers())) > 0:
        raise NotImplementedError(
            "This layer has buffers and is not supported by Opacus"
        )
    if type(layer) is nn.EmbeddingBag:
        raise NotImplementedError("Functorch does not support EmbeddingBag yet")

    flayer, _ = make_functional(layer)

    def compute_loss_stateless_model(params, activations, backprops):
        if batch_first or type(layer) is RNNLinear:
            batched_activations = activations.unsqueeze(0) # HERE the activations are assumed to be a single tensor
            batched_backprops = backprops.unsqueeze(0)
        else:
            # If batch_first is False, the batch dimension is the second dimension
            batched_activations = activations.unsqueeze(1)
            batched_backprops = backprops.unsqueeze(1)

        output = flayer(params, batched_activations)
        loss = (output * batched_backprops).sum()
        return loss

    ft_compute_grad = grad(compute_loss_stateless_model)
    # Note that the vmap is done on the first dimension, regardless of batch_first
    # This is because the activations and backprops given by the GradSampleModule
    # are always batch_first=True
    layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
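
(Side note, not part of the Opacus code: to check that I am reading this mechanism correctly, the same grad + vmap recipe in isolation, on a plain nn.Linear, would look roughly like this sketch of mine.)

import torch
import torch.nn as nn
from functorch import grad, make_functional, vmap

linear = nn.Linear(4, 3)
flinear, params = make_functional(linear)

def loss_fn(params, x, backprop):
    # redo the forward pass for a single sample (unsqueezed to a batch of 1)
    out = flinear(params, x.unsqueeze(0))
    return (out * backprop.unsqueeze(0)).sum()

# vmap over the batch dimension of the activations and backprops, not the params
per_sample_grad_fn = vmap(grad(loss_fn), in_dims=(None, 0, 0))

x = torch.randn(8, 4)          # batch of 8 activations
backprops = torch.randn(8, 3)  # gradient of the loss w.r.t. the outputs
per_sample_grads = per_sample_grad_fn(params, x, backprops)
# per_sample_grads[0].shape == (8, 3, 4): one weight gradient per sample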


def ft_compute_per_sample_gradient(layer, activations, backprops):
    """
    Compute the per-sample gradient of the layer.
    Args:
        layer: the layer on which to compute the gradient
        activations: the input to the layer
        backprops: the gradient of the loss w.r.t. outputs of the layer
    """
    parameters = list(layer.parameters(recurse=True))
    if not hasattr(layer, "ft_compute_sample_grad"):
        prepare_layer(layer)

    per_sample_grads = layer.ft_compute_sample_grad(
        parameters, activations[0], backprops  # HERE only the first element of stored activations is used. why?
    )

    ret = {}
    for i_p, p in enumerate(parameters):
        ret[p] = per_sample_grads[i_p]

    return ret
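
To make the failure mode concrete, here is a minimal reproduction outside of Opacus, using nn.Bilinear as a stand-in for a module whose forward takes more than one tensor (my own sketch, not Opacus code):

import torch
import torch.nn as nn
from functorch import make_functional

layer = nn.Bilinear(3, 4, 5)   # forward(input1, input2)
flayer, params = make_functional(layer)

x1 = torch.randn(8, 3)
x2 = torch.randn(8, 4)

out = flayer(params, x1, x2)   # works: both inputs are forwarded

# This is effectively what happens today, because only activations[0]
# reaches compute_loss_stateless_model:
flayer(params, x1)             # raises TypeError: missing positional argument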

My question is: why does Opacus assume that the module can only receive a single input? Or am I misunderstanding what is currently being done? In ft_compute_per_sample_gradient, only activations[0] is passed instead of the whole list of inputs previously stored by the forward hook. As a result, any module that requires multiple inputs fails in the forward pass (as in the nn.Bilinear reproduction above). I use the following solution:

# pass all stored inputs here instead of just activations[0]
per_sample_grads = layer.ft_compute_sample_grad(
    parameters, activations, backprops
)

# unsqueeze every element of the list here, then pass the list to the
# functional layer unrolled
if batch_first or type(layer) is RNNLinear:
    batched_activations = [act.unsqueeze(0) for act in activations]
    batched_backprops = backprops.unsqueeze(0)
else:
    # If batch_first is False, the batch dimension is the second dimension
    batched_activations = [act.unsqueeze(1) for act in activations]
    batched_backprops = backprops.unsqueeze(1)

output = flayer(params, *batched_activations)

Is this the correct approach for a module that takes multiple positional inputs?
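
For completeness, here is a self-contained sketch of the patched path with both changes applied, plus a quick shape check on nn.Bilinear (the buffer/EmbeddingBag guards and the RNNLinear special case are trimmed for brevity, and nn.Bilinear is only my stand-in for a multi-input module):

import torch
import torch.nn as nn
from functorch import grad, make_functional, vmap


def prepare_layer(layer, batch_first=True):
    flayer, _ = make_functional(layer)

    def compute_loss_stateless_model(params, activations, backprops):
        # activations is now the full list of inputs captured by the forward hook
        dim = 0 if batch_first else 1
        batched_activations = [act.unsqueeze(dim) for act in activations]
        batched_backprops = backprops.unsqueeze(dim)

        # unroll the list so modules with several positional inputs still work
        output = flayer(params, *batched_activations)
        loss = (output * batched_backprops).sum()
        return loss

    ft_compute_grad = grad(compute_loss_stateless_model)
    # my assumption: an integer in_dims entry is broadcast over every tensor
    # in the activations list, so each input is still vmapped over dimension 0
    layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))


def ft_compute_per_sample_gradient(layer, activations, backprops):
    parameters = list(layer.parameters(recurse=True))
    if not hasattr(layer, "ft_compute_sample_grad"):
        prepare_layer(layer)

    # pass the whole list of stored activations instead of activations[0]
    per_sample_grads = layer.ft_compute_sample_grad(
        parameters, activations, backprops
    )
    return {p: per_sample_grads[i] for i, p in enumerate(parameters)}


# Quick shape check on a module whose forward takes two tensors
layer = nn.Bilinear(3, 4, 5)
x1, x2 = torch.randn(8, 3), torch.randn(8, 4)
backprops = torch.randn(8, 5)

grads = ft_compute_per_sample_gradient(layer, [x1, x2], backprops)
weight, bias = layer.parameters()
print(grads[weight].shape)  # expected torch.Size([8, 5, 3, 4]): one weight gradient per sample
print(grads[bias].shape)    # expected torch.Size([8, 5])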