'GradSampleModule' object has no attribute 'calc_gradient_penalty'

Hello Opacus Team.

After wrapping an nn.Module with GradSampleModule, how can I access the custom methods (e.g., def init_weights(self):) of the underlying model?

Specifically, I want to use Opacus without the PrivacyEngine to privatize the Discriminator of CT-GAN; however, after wrapping, the call to calc_gradient_penalty fails. A minimal reproduction follows the class definition below.

import torch
from torch.nn import Dropout, LeakyReLU, Linear, Module, Sequential


class Discriminator(Module):
    """Discriminator for the CTGANSynthesizer."""

    def __init__(self, input_dim, discriminator_dim, pac=10):
        super(Discriminator, self).__init__()
        dim = input_dim * pac
        self.pac = pac
        self.pacdim = dim
        seq = []
        for item in list(discriminator_dim):
            seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
            dim = item
        seq += [Linear(dim, 1)]
        self.seq = Sequential(*seq)

    def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
        """Compute the gradient penalty."""
        # Random interpolation coefficients, shared within each pac group
        alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, real_data.size(1))
        alpha = alpha.view(-1, real_data.size(1))

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        disc_interpolates = self(interpolates)

        # Second-order gradient through the interpolates (WGAN-GP penalty)
        gradients = torch.autograd.grad(
            outputs=disc_interpolates, inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
            create_graph=True, retain_graph=True, only_inputs=True,
        )[0]

        gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
        gradient_penalty = (gradients_view ** 2).mean() * lambda_

        return gradient_penalty

    def forward(self, input_):
        """Apply the Discriminator to the `input_`."""
        assert input_.size()[0] % self.pac == 0
        return self.seq(input_.view(-1, self.pacdim))
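
Minimal reproduction sketch (input_dim=5, discriminator_dim=(32,), and the batch size of 20 are made-up illustrative values):

import torch
from opacus.grad_sample import GradSampleModule

disc = Discriminator(input_dim=5, discriminator_dim=(32,))
wrapped = GradSampleModule(disc)

real = torch.randn(20, 5)
fake = torch.randn(20, 5)

out = wrapped(real)  # forward() is delegated to the wrapped module and works

# Custom methods are not forwarded by the wrapper, so this raises:
# AttributeError: 'GradSampleModule' object has no attribute 'calc_gradient_penalty'
penalty = wrapped.calc_gradient_penalty(real, fake)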

Best regards!

Hi,

Thanks for reaching out. Unfortunately, Opacus does not yet support advanced computation-graph manipulations (such as torch.autograd.grad()). We are currently looking at functorch to potentially support operations of that kind in the future.
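
In the meantime, the original module's methods are still reachable through the wrapper. A minimal sketch, assuming the current GradSampleModule layout (it stores the wrapped model in its _module attribute and exposes to_standard_module()):

disc = Discriminator(input_dim=5, discriminator_dim=(32,))
wrapped = GradSampleModule(disc)

penalty_fn = wrapped._module.calc_gradient_penalty  # reach the original method

plain = wrapped.to_standard_module()  # or unwrap back to the plain Discriminator

# Note: this only avoids the AttributeError; the torch.autograd.grad()
# call inside calc_gradient_penalty remains the unsupported operation.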

Thank you for your quick reply. I guess I’ll try another GAN variant then.