IndexError: pop from empty list in grad_sample_module.py for opacus version > 0.9

Hi,

I'm using Opacus to make CT-GAN (GitHub - sdv-dev/CTGAN: Conditional GAN for generating synthetic tabular data) differentially private.
There is already an implementation that does this: (smartnoise-sdk/dpctgan.py at main · opendp/smartnoise-sdk · GitHub)
However, it uses an older version of Opacus (v0.9) and CTGAN (v0.2.2.dev1).
I used their method to make the newest version of CTGAN differentially private with the newest Opacus version. Unfortunately, I run into the following error:

.../CTGAN_DP/DP_CTGAN.py", line 309, in fit
    loss_d.backward()
  File ".../lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File ".../python3.7/site-packages/opacus/grad_sample/grad_sample_module.py", line 197, in capture_backprops_hook
    module, backprops, loss_reduction, batch_first
  File ".../python3.7/site-packages/opacus/grad_sample/grad_sample_module.py", line 234, in rearrange_grad_samples
    A = module.activations.pop()
IndexError: pop from empty list

The function in question is:

 def rearrange_grad_samples(
        self,
        module: nn.Module,
        backprops: torch.Tensor,
        loss_reduction: str,
        batch_first: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rearrange activations and grad_samples based on loss reduction and batch dim

        Args:
            module: the module for which per-sample gradients are computed
            backprops: the captured backprops
            loss_reduction: either "mean" or "sum" depending on whether backpropped
            loss was averaged or summed over batch
            batch_first: True if batch dimension is first
        """
        if not hasattr(module, "activations"):
            raise ValueError(
                f"No activations detected for {type(module)},"
                " run forward after add_hooks(model)"
            )

        batch_dim = 0 if batch_first or type(module) is LSTMLinear else 1

        if isinstance(module.activations, list):
            A = module.activations.pop()
        else:
            A = module.activations

        if not hasattr(module, "max_batch_len"):
            # For packed sequences, max_batch_len is set in the forward of the model (e.g. the LSTM)
            # Otherwise we infer it here
            module.max_batch_len = _get_batch_size(module, A, batch_dim)

        n = module.max_batch_len
        if loss_reduction == "mean":
            B = backprops * n
        elif loss_reduction == "sum":
            B = backprops
        else:
            raise ValueError(
                f"loss_reduction = {loss_reduction}. Only 'sum' and 'mean' losses are supported"
            )

        # No matter where the batch dimension was, .grad_samples will *always* put it in the first dim
        if batch_dim != 0:
            A = A.permute([batch_dim] + [x for x in range(A.dim()) if x != batch_dim])
            B = B.permute([batch_dim] + [x for x in range(B.dim()) if x != batch_dim])

        return A, B

This does not happen with Opacus v0.9.
I investigated and found that module.activations gets popped until only an empty list is left for the first module, which then produces this error.

I hacked my way around this issue as follows:

        if isinstance(module.activations, list):
            # print(len(module.activations))
            if len(module.activations) > 1:
                A = module.activations.pop()
            else:
                A = module.activations[0]
        else:
            A = module.activations

That is, if module.activations is left with one element instead of an empty list, at least training works.
My question is: am I breaking anything important by doing this, and could this be a potential subcase which was not accounted for?
Or do I have to change something else in the model?
The model I try to train with Opacus is basically just a composition of n × (nn.Linear, nn.ReLU, nn.Dropout), which should be fine I think.
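
Concretely, the discriminator is roughly something like this (the dimensions and number of blocks here are made up for illustration, not CTGAN's actual defaults):

import torch.nn as nn

def make_discriminator(input_dim=128, hidden_dim=256, n_blocks=2):
    # n blocks of (Linear -> ReLU -> Dropout), followed by a scalar output head
    layers, dim = [], input_dim
    for _ in range(n_blocks):
        layers += [nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Dropout(0.5)]
        dim = hidden_dim
    layers.append(nn.Linear(dim, 1))
    return nn.Sequential(*layers)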

Thanks for a reply :slight_smile:
Have a great day.

Thanks for flagging this. Your fix will probably lead to incorrect gradient computations (and probably break the privacy guarantees as well). Normally, at each forward pass the activations get pushed onto the module.activations list, and they get popped in the backward pass. Popping from an empty list indicates that there is one more backward pass than forward passes. I am not sure why the list becomes empty in your case; do you mind sharing a minimal reproducing example?
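
To illustrate that bookkeeping, here is a minimal, self-contained sketch in plain PyTorch (not Opacus' actual hook code; it just mimics the push/pop idea, and register_full_backward_hook needs PyTorch 1.8+): a forward hook pushes the input activations, a backward hook pops them, and calling backward one more time than forward pops from an empty list.

import torch
import torch.nn as nn

layer = nn.Linear(4, 1)
layer.activations = []

def push_activation(module, inputs, output):
    # mimics the forward hook: store the activations for this forward pass
    module.activations.append(inputs[0].detach())

def pop_activation(module, grad_input, grad_output):
    # mimics the backward hook: consume one stored activation per backward pass
    module.activations.pop()

layer.register_forward_hook(push_activation)
layer.register_full_backward_hook(pop_activation)

x = torch.randn(8, 4, requires_grad=True)
out = layer(x).sum()             # one forward pass  -> one push (len == 1)
out.backward(retain_graph=True)  # one backward pass -> one pop  (len == 0)
out.backward()                   # extra backward pass -> pop from empty list -> IndexError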

I am trying to do the exact same thing (run CTGAN in a differentially private manner using opacus). I have the original code of CTGAN (same as shown above) but just changed the optimizerD by attaching PrivacyEngine to it (no other changes).

And I get the exact same error at loss_d.backward().
When I print loss_d just before the backward call, I get this: tensor(-0.0344, grad_fn=<NegBackward>).
I do not see any other backward call before this step, so I am puzzled as to why this is happening.

@shaanchandra and @knilox It would be helpful if you could share with us a minimal reproducible example (for example, in Colab) with CTGAN so we can take a closer look at why your module.activations becomes empty. We’ll be happy to take a look at this.

@shaanchandra @alexandresablayrolle @sayanghosh,
Thanks for all your replies and sorry for not responding in a timely manner.
I will begin creating a minimal example now and post it here soon.

@sayanghosh @alexandresablayrolle
Here is my Colab example: Google Colab.

Again thanks for looking into this. Hope this is enough.


Hi @knilox thanks for the example, we are looking at it now and will provide an update soon.

Hello again,
After digesting the tip from @alexandresablayrolle about multiple backward passes, I think I have now localized the root cause of the problem. Not hard to find after that tip, tbh.
Nevertheless, I was a little confused, because it indicated a problem with the unmodified model as well. But training the unmodified model works just fine.

The problem is the gradient penalty, which is added to the loss function of the discriminator.
The penalty is added to enforce a soft 1-Lipschitz constraint on the gradient norms to stabilize convergence with the Wasserstein distance. That is, the model is regularized to have gradient norms of 1. The other way of enforcing a 1-Lipschitz constraint is to clip the gradient norms at 1.
Usually the gradient penalty is preferred, because it performs much better.
However, as you know, for DP we have to clip the gradient norms anyway :skull:
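
For reference, this is roughly what such a gradient penalty looks like (a generic WGAN-GP-style sketch, not CTGAN's exact calc_gradient_penalty; CTGAN's pac-wise grouping is omitted). Note that it contains one extra forward pass through the discriminator and one extra backward pass via autograd.grad, which matters for the activation bookkeeping discussed above.

import torch

def gradient_penalty(disc, real, fake, device):
    alpha = torch.rand(real.size(0), 1, device=device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interp = disc(interpolates)              # extra forward pass through the discriminator
    grads = torch.autograd.grad(               # extra backward pass to get dD/dx
        outputs=d_interp,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interp),
        create_graph=True,
    )[0]
    # push gradient norms towards 1 (soft 1-Lipschitz constraint)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()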

So the easiest fix for the problem is just removing pen.backward(retain_graph=True) and setting max_grad_norm = 1, and we should be fine, except for a potentially heavy loss in utility.
However, I am wondering whether the gradient penalty still makes sense in this situation if we set max_grad_norm > 1, e.g. 2, and let the model enforce gradient norms of 1 itself.
The reason is that I feel uncomfortable removing features of CT-GAN. But if that is the price of privacy, I guess we have to pay it.

# pen = discriminator.calc_gradient_penalty(
#     real_cat, fake_cat, self._device, self.pac)
loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

optimizerD.zero_grad()
# pen.backward(retain_graph=True)
loss_d.backward()
optimizerD.step()
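
For completeness, this is roughly where max_grad_norm enters when attaching Opacus to optimizerD. The PrivacyEngine constructor arguments differ between Opacus versions (newer releases use make_private(...) instead of attach), so the argument names below are illustrative rather than exact; discriminator and optimizerD refer to the CTGAN objects from the snippet above, and batch_size / train_data_size are placeholders.

from opacus import PrivacyEngine

# Sketch only: check the PrivacyEngine signature of your installed Opacus version.
privacy_engine = PrivacyEngine(
    discriminator,                               # module whose per-sample gradients are tracked
    sample_rate=batch_size / train_data_size,    # name of this argument is version-dependent
    noise_multiplier=1.0,
    max_grad_norm=1.0,                           # clip per-sample gradient norms at 1
)
privacy_engine.attach(optimizerD)                # the "attach" style used by the 0.x API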

In addition to that, as far as I understood, using retain_graph=True should be equivalent to:

loss_d = -(torch.mean(y_real) - torch.mean(y_fake)) + pen
optimizerD.zero_grad()
loss_d.backward()
optimizerD.step()

But this also produces the same error, and I don't understand why.
If the regularization is truly the problem, my intuition tells me that this should be supported. However, maybe this is intended, such that silly constraints are not enforced.

PS: after finding this, I reviewed the code of smartnoise-sdk (second link) again and found they also removed the gradient penalty; I had just missed it. They also do not mention it in their paper.

For the case you tried out with pen not undergoing a separate backward pass, we still add it to the total loss function and thus to the computation graph. Generating y_fake and y_real always takes two forward passes, and thus two stack pushes for each layer. During computation of the gradient penalty, we have a backward pass (because of calling autograd.grad) and a forward pass (note the self(interpolates), which calls the network). When backward() is called on loss_d, backward is then expected to be called three times, as three corresponding forward passes were observed. This does not take into account the fact that an additional backward() already got called during the loss computation. We subsequently have more backward passes than forward passes, and so the activation stack empties out for each layer. This might explain why the error arises whether you add pen to the loss or just call pen.backward(): it is due to an additional grad computation which was not taken into account.


For the problem of supporting gradient-regularization-style techniques (for example in CTGAN or Wasserstein GAN), or more specifically your case: the autograd.grad call is done to estimate the gradient regularizer, not to train, so we should not need to go the route of per-sample gradient estimation, clipping, and noise addition which DP would do here. In contrast, we would need that during the pen.backward().


Hi @knilox and @sayanghosh ,

Thank you for the discussions. The workings are very clear to me now.
However, I am not clear on what consensus has been reached here to solve this specific problem.
Do we remove the gradient penalty term completely from the code?
If we keep it then how do we do the DP in this case using opacus?

Again, I understand what is discussed and it all makes sense. But the prescribed way forward is not clear to me.

@shaanchandra From the discussion above, it seems that there are two potential solutions:

  1. As @knilox mentioned above, remove the gradient regularization loss entirely. This may seem like we are not faithfully reproducing the original CTGAN; however, DP eventually clips the gradients to bound the SGD sensitivity, which automatically introduces some regularization. Further, it appears that the smartnoise-sdk implementation which is being followed here already does this.
  2. Keep the gradient regularization loss, but do not compute per-sample gradients during the call to autograd.grad (basically keep the backward hook disabled during that time), since it is for estimating gradients and not for updating the network (a rough sketch of this idea follows below). We have tested this out and it works initially; however, it seems to create a new issue during the .step(), where it complains that the norms are of unequal length.

So our recommendation is to do (1). Option (2)'s error is probably unrelated, and that option is also a potential fix.
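
For reference, a rough sketch of the idea behind option (2) in the discriminator step. It assumes the discriminator is wrapped by Opacus' GradSampleModule and that your Opacus version exposes disable_hooks()/enable_hooks() on it (recent versions do; older ones may differ). As a slight variation on the description above, the hooks here stay disabled through pen.backward() as well so the activation stack stays balanced, which also means the penalty's gradient contribution bypasses per-sample clipping and would have to be accounted for in the privacy analysis. Since this route still hit a separate error at .step() in our tests, treat it purely as an illustration, not a working fix.

optimizerD.zero_grad()

# temporarily stop the per-sample gradient hooks while the penalty's extra
# forward pass, its autograd.grad call, and its backward run, so they do not
# disturb the activation bookkeeping for loss_d
discriminator.disable_hooks()
pen = discriminator.calc_gradient_penalty(
    real_cat, fake_cat, self._device, self.pac)
pen.backward(retain_graph=True)
discriminator.enable_hooks()

loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
loss_d.backward()
optimizerD.step()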

@sayanghosh I have one final question w.r.t. differential privacy that is unrelated to the original post.
CTGAN adds an additional penalty term to the generator loss, which is the cross-entropy loss of how many conditions are fulfilled by the generator in a batch.
To calculate this loss, we need the univariate distributions of all categorical attributes in the training data.
Thus, the generator technically depends directly on the training data, making it not fully differentially private. The other input depends on the discriminator, but that is differentially private because of Opacus.

I assume that’s why the smartnoise-sdk simply removes it. But I currently think that this is not necessary.
And my question here is: are there flaws in the following arguments?

1: We could make the calculation of the penalty term differentially private, and then we have a differentially private discriminator and a differentially private penalty term. Thus, the generator loss is differentially private and we have a differentially private generator after training (a rough sketch of this follows after argument 2).

2: If we assume the univariate distributions of categorical attributes to be public knowledge, which is commonly assumed, we don't have to make the penalty term satisfy differential privacy and can still call the final generator differentially private.
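
To make argument 1 concrete, here is a rough sketch (entirely hypothetical, not part of CTGAN or smartnoise-sdk) of how the per-category counts behind the penalty term could be made differentially private with the Laplace mechanism. For a single categorical attribute under add/remove-one-record adjacency, the count histogram has L1 sensitivity 1, so Laplace noise with scale 1/eps gives eps-DP for that attribute; the eps spent here composes with whatever Opacus spends on the discriminator, and it has to be composed across attributes as well.

import numpy as np

def dp_category_frequencies(values, categories, eps):
    # noisy, normalized counts for one categorical attribute (Laplace mechanism)
    counts = np.array([(values == c).sum() for c in categories], dtype=float)
    noisy = counts + np.random.laplace(loc=0.0, scale=1.0 / eps, size=len(counts))
    noisy = np.clip(noisy, 0.0, None)           # counts cannot be negative
    total = noisy.sum()
    if total <= 0:                              # degenerate case: fall back to uniform
        return np.full(len(counts), 1.0 / len(counts))
    return noisy / total

# hypothetical usage:
# freqs = dp_category_frequencies(np.array(["a", "b", "a", "c"]), ["a", "b", "c"], eps=0.5)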

Thanks again :slight_smile:
Your help is much appreciated!!


Hi, I am interested in the part "basically keep the backward hook disabled", as I also run into a similar problem with multiple backward calls. Would you mind giving a bit more detail here? Thanks!