I found a weird computation cost when I was trying to compute Hessian vector product of DCGAN.
The Hessian-vector product for DCGAN costs 100x more than for the other GAN, even though DCGAN contains fewer parameters. I'm confused because the problem does not seem to be caused by any single module: the other GAN contains linear, pooling, and convolution layers, while DCGAN only has convolutions. Could anyone figure out what's going on here?
I use CUDA events to measure time cost of backward of each model and Hessian vector product. And here is the result.
=====Test1=====
Discriminator backward : 10.1550083ms
Generator backward : 6.1214719ms
Hessian vector product : 13.5014400ms
=====Test2=====
Discriminator backward : 18.8948479ms
Generator backward : 31.2074242ms
Hessian vector product : 1206.1020508ms
The batchnorms and activation functions don't make a difference, so I removed all of them to keep the code clear. The math form of the Hessian-vector product here is
$\frac{\partial^2 f}{\partial G\,\partial D}\cdot\frac{\partial f}{\partial D}$
class DC_generator(nn.Module):
    """DCGAN-style generator.

    Maps a (N, z_dim, 1, 1) latent batch to a (N, channel_num, 64, 64)
    image through five stacked transposed convolutions. BatchNorm and
    activations are intentionally omitted for the timing experiment.
    """

    def __init__(self, z_dim=100, channel_num=3, feature_num=64):
        super(DC_generator, self).__init__()
        blocks = [
            # z_dim x 1x1 -> (feature_num * 8) x 4x4
            nn.ConvTranspose2d(z_dim, feature_num * 8, kernel_size=4, stride=1, padding=0, bias=True),
            # -> (feature_num * 4) x 8x8
            nn.ConvTranspose2d(feature_num * 8, feature_num * 4, kernel_size=4, stride=2, padding=1, bias=True),
            # -> (feature_num * 2) x 16x16
            nn.ConvTranspose2d(feature_num * 4, feature_num * 2, kernel_size=4, stride=2, padding=1, bias=True),
            # -> feature_num x 32x32
            nn.ConvTranspose2d(feature_num * 2, feature_num, kernel_size=4, stride=2, padding=1, bias=True),
            # -> channel_num x 64x64
            nn.ConvTranspose2d(feature_num, channel_num, kernel_size=4, stride=2, padding=1, bias=True),
        ]
        self.main = nn.Sequential(*blocks)

    def forward(self, input):
        """Upsample the latent batch into image space."""
        return self.main(input)
class DC_discriminator(nn.Module):
    """DCGAN-style discriminator.

    Reduces a (N, channel_num, 64, 64) image batch to a (N, 1, 1, 1)
    logit through five stacked strided convolutions. BatchNorm and
    activations are intentionally omitted for the timing experiment.
    """

    def __init__(self, channel_num=3, feature_num=64):
        super(DC_discriminator, self).__init__()
        blocks = [
            # channel_num x 64x64 -> feature_num x 32x32
            nn.Conv2d(channel_num, feature_num, kernel_size=4, stride=2, padding=1, bias=True),
            # -> (feature_num * 2) x 16x16
            nn.Conv2d(feature_num, feature_num * 2, kernel_size=4, stride=2, padding=1, bias=True),
            # -> (feature_num * 4) x 8x8
            nn.Conv2d(feature_num * 2, feature_num * 4, kernel_size=4, stride=2, padding=1, bias=True),
            # -> (feature_num * 8) x 4x4
            nn.Conv2d(feature_num * 4, feature_num * 8, kernel_size=4, stride=2, padding=1, bias=True),
            # -> 1 x 1x1 (one logit per sample)
            nn.Conv2d(feature_num * 8, 1, kernel_size=4, stride=1, padding=0, bias=True),
        ]
        self.main = nn.Sequential(*blocks)

    def forward(self, input):
        """Score the image batch with the stacked convolutions."""
        return self.main(input)
class dc_D(nn.Module):
    """Small CNN discriminator for 32x32 RGB inputs.

    Two conv + LeakyReLU + max-pool stages feed a two-layer MLP that
    emits one logit per sample.
    """

    def __init__(self):
        super(dc_D, self).__init__()
        # Spatial bookkeeping (all with 128 channels after the first conv):
        # 3 x 32x32 -> conv5 -> 128 x 28x28 -> pool -> 128 x 14x14
        #           -> conv5 -> 128 x 10x10 -> pool -> 128 x 5x5
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=5, stride=1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
        )
        # Flattened feature size: 128 * 5 * 5 = 3200 = 1600 * 2.
        self.fc = nn.Sequential(
            nn.Linear(1600 * 2, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return one logit per sample for a (N, 3, 32, 32) batch."""
        features = self.conv(x)
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)
class dc_G(nn.Module):
    """Small hybrid generator.

    An MLP expands the latent code to a 128-channel 8x8 feature map,
    then two transposed convolutions upsample it to a 3x32x32 image.
    """

    def __init__(self, z_dim=96):
        super(dc_G, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(z_dim, 1024),
            nn.Linear(1024, 8 * 8 * 128),
        )
        self.convt = nn.Sequential(
            # 128 x 8x8 -> 64 x 16x16
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            # 64 x 16x16 -> 3 x 32x32
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Map a (N, z_dim) latent batch to a (N, 3, 32, 32) image batch."""
        hidden = self.fc(x)
        feature_map = hidden.view(hidden.shape[0], 128, 8, 8)
        return self.convt(feature_map)
def hvptest(D, G, z):
    """Time D-backward, G-backward, and the mixed Hessian-vector product.

    Computes loss = BCEWithLogits(D(G(z)), 0), then times three autograd
    calls with CUDA events:
      1. grad of loss w.r.t. D's parameters (graph kept for double backward),
      2. grad of loss w.r.t. G's parameters,
      3. grad_G( grad_D(loss) . v ) with v = grad_D(loss) detached, i.e.
         the cross Hessian-vector product  d^2f/(dG dD) . df/dD.

    Args:
        D: discriminator module (callable, with .parameters()).
        G: generator module (callable, with .parameters()).
        z: latent batch; D, G and z must all live on the same CUDA device
           (CUDA events cannot time CPU execution).
    """
    # Derive the device from the input instead of relying on a module-level
    # global, so the function also works when imported from another script.
    device = z.device
    criterion = nn.BCEWithLogitsLoss()
    d_fake = D(G(z))
    loss = criterion(d_fake, torch.zeros(d_fake.shape, device=device))

    def _timed_grad(outputs, params, grad_outputs=None, create_graph=False):
        # Run one autograd.grad call bracketed by CUDA events; return the
        # gradients and the elapsed GPU time in milliseconds.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        grads = torch.autograd.grad(outputs, params, grad_outputs=grad_outputs,
                                    create_graph=create_graph, retain_graph=True)
        end.record()
        torch.cuda.synchronize(device=device)
        return grads, start.elapsed_time(end)

    # create_graph=True keeps the backward graph so we can differentiate
    # grad_d again below.
    grad_d, ms = _timed_grad(loss, D.parameters(), create_graph=True)
    print('Discriminator backward : %.7fms' % ms)
    grad_d_vec = torch.cat([g.contiguous().view(-1) for g in grad_d])

    grad_g, ms = _timed_grad(loss, G.parameters(), create_graph=True)
    print('Generator backward : %.7fms' % ms)
    grad_g_vec = torch.cat([g.contiguous().view(-1) for g in grad_g])
    # print('Discriminator parameter number: %d' % grad_d_vec.numel())
    # print('Generator parameter number: %d' % grad_g_vec.numel())

    # v is a constant copy of grad_d; the grad call below backprops the
    # scalar <grad_d, v> through grad_d's graph into G's parameters.
    vec_d = grad_d_vec.clone().detach()
    hvp_d, ms = _timed_grad(grad_d_vec, G.parameters(), grad_outputs=vec_d)
    print('Hessian vector product : %.7fms' % ms)
if __name__ == '__main__':
    # NOTE: the timing harness relies on CUDA events, so a GPU is effectively
    # required; on the CPU fallback below hvptest would raise.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch_size = 128

    # Test1: mixed linear/conv/pool GAN pair on 32x32-style inputs.
    # (The unused CIFAR-shaped image tensor from the original script was
    # removed — only the latent batch feeds hvptest.)
    print('=====Test1=====')
    D = dc_D().to(device)
    G = dc_G(z_dim=96).to(device)
    z_mnist = torch.randn((batch_size, 96), device=device)
    hvptest(D=D, G=G, z=z_mnist)

    # Test2: all-convolutional DCGAN pair on 64x64-style inputs.
    print('=====Test2=====')
    D = DC_discriminator(channel_num=3, feature_num=64).to(device)
    G = DC_generator(z_dim=100, channel_num=3, feature_num=64).to(device)
    z_celeba = torch.randn((batch_size, 100, 1, 1), device=device)
    hvptest(D=D, G=G, z=z_celeba)
PS: Pytorch version: 1.1.0, GPU: Tesla V100