Not sure if this layer is DP-friendly

Hi, I have an implementation of PG-GAN in hand, and the discriminator of this model contains a so-called minibatch-std layer. Opacus seems to validate the layer without any problems. However, what this layer actually does is append a new statistic as an extra feature to every sample, and this statistic depends on the other samples in the batch.
For example, given tensor([[1, 2, 3], [3, 4, 5]]) it outputs tensor([[1, 2, 3, 1.4142], [3, 4, 5, 1.4142]]).
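
For reference, that extra value can be reproduced directly (note that torch.std needs a float tensor, so the integer literals above are shorthand):

import torch

x = torch.tensor([[1., 2., 3.], [3., 4., 5.]])
std = torch.std(x, dim=0)        # per-feature std across the batch: all 1.4142
batch_stat = torch.mean(std)     # one scalar shared by every sample
out = torch.cat((x, batch_stat.repeat(2, 1)), dim=1)
# tensor([[1.0000, 2.0000, 3.0000, 1.4142],
#         [3.0000, 4.0000, 5.0000, 1.4142]])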

The layer looks like this:

import torch
import torch.nn as nn


class Minibatch_std(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Target shape (N, 1, ...) for the extra feature channel
        size = list(x.size())
        size[1] = 1

        # Per-feature standard deviation across the batch, averaged into
        # a single scalar that is then broadcast to every sample
        std = torch.std(x, dim=0)
        mean = torch.mean(std)
        return torch.cat((x, mean.repeat(size)), dim=1)
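
On a 4D conv input the same logic appends one constant feature map per sample, which is why conv1 in the initial block below takes in_ch + 1 channels:

x = torch.randn(8, 16, 4, 4)    # (N, C, H, W)
out = Minibatch_std()(x)
print(out.shape)                # torch.Size([8, 17, 4, 4])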

and is used here:

class D_Block(nn.Module):
    def __init__(self, in_ch, out_ch, initial_block=False):
        super().__init__()

        if initial_block:
            # Final 4x4 block: minibatch-std adds one extra channel,
            # hence in_ch + 1 for the first convolution
            self.minibatchstd = Minibatch_std()
            self.conv1 = EqualizedLR_Conv2d(in_ch + 1, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.conv2 = EqualizedLR_Conv2d(out_ch, out_ch, kernel_size=(4, 4), stride=(1, 1))
            self.outlayer = nn.Sequential(
                nn.Flatten(),
                nn.Linear(out_ch, 1)
            )
        else:
            self.minibatchstd = None
            self.conv1 = EqualizedLR_Conv2d(in_ch, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.conv2 = EqualizedLR_Conv2d(out_ch, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.outlayer = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.relu = nn.LeakyReLU(0.2)
        nn.init.normal_(self.conv1.weight)
        nn.init.normal_(self.conv2.weight)
        nn.init.zeros_(self.conv1.bias)
        nn.init.zeros_(self.conv2.bias)

    def forward(self, x):
        if self.minibatchstd is not None:
            x = self.minibatchstd(x)

        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.outlayer(x)
        return x

Is this layer DP-friendly?

Note: the implementation is taken from Maggiking/PGGAN-PyTorch on GitHub (a PyTorch implementation of Progressive Growing GAN).

Thanks for reporting this. I believe your concern is valid. For layers that don't have an explicit per-sample gradient implementation, Opacus computes the per-sample gradients on a "best-effort" basis. Unfortunately, mixing information across the batch is definitely not DP-friendly.
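
To illustrate why this slips through: the layer presumably passes validation because it has no trainable parameters and is not one of the module types the validator knows to flag, so a clean validation result is not a DP guarantee for custom layers. A minimal sketch of the check (assuming Opacus 1.x, where the validator lives in opacus.validators, and the Minibatch_std class from above):

import torch.nn as nn
from opacus.validators import ModuleValidator

# Toy discriminator head using the Minibatch_std layer defined above
model = nn.Sequential(Minibatch_std(), nn.Linear(4, 1))

# Returns an empty list (no errors) even though Minibatch_std mixes
# information across samples; the check inspects module types, not
# cross-sample data flow
errors = ModuleValidator.validate(model, strict=False)
print(errors)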
