Hi,
I am running the StarGAN code [https://github.com/yunjey/stargan] and have modified the discriminator slightly:
class Discriminator(nn.Module):
    """Discriminator network with PatchGAN.

    A stack of ``repeat_num`` stride-2, 4x4 convolutions (each followed by
    LeakyReLU(0.01)) downsamples the input. Three heads read the resulting
    feature map ``h``:

    * ``conv1`` -> ``out_src``: a 1-channel map with the same spatial size as ``h``.
    * ``conv2`` -> ``out_cls``: ``c_dim`` values per image.
    * ``c3`` / ``c4`` / ``c5`` -> ``out_feat``: a 128-dim vector per image.

    Args:
        image_size: Height/width of the square input image.
        conv_dim: Output channels of the first convolution; every following
            convolution doubles the channel count.
        c_dim: Number of outputs of the classification head.
        repeat_num: Number of stride-2 convolutions.
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.01),
        ]
        curr_dim = conv_dim
        for _ in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim *= 2
        # Each 4x4 / stride-2 / padding-1 conv maps a side of length H to H // 2,
        # so an image_size x image_size input ends up kernel_size x kernel_size.
        kernel_size = image_size // 2 ** repeat_num
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
        # Flattened size of h, derived from the architecture. A hard-coded 2048
        # only fits a 1x1 final map with 2048 channels (e.g. image_size=64,
        # conv_dim=64, repeat_num=6) and breaks for the 128px defaults.
        self.c3 = nn.Linear(curr_dim * kernel_size * kernel_size, 512)
        self.c4 = nn.ReLU6()
        self.c5 = nn.Linear(512, 128)

    def forward(self, x):
        """Run the discriminator on a batch of images.

        Args:
            x: Image batch; assumed (N, 3, image_size, image_size) since the
               first conv takes 3 input channels.

        Returns:
            Tuple ``(out_feat, out_src, out_cls)`` with shapes (N, 128),
            (N, 1, h, w) and (N, c_dim).
        """
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        print(h.shape)  # debug output: shape of the conv features
        out_feat = h.flatten(1)
        out_feat = self.c3(out_feat)
        out_feat = self.c4(out_feat)
        out_feat = self.c5(out_feat)
        # No squeeze(): on (N, 128) it is a no-op except when N == 1, where it
        # silently dropped the batch dimension.
        return out_feat, out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
In `forward`, I print the output shape of the conv blocks with `print(h.shape)`,
and it yields:
torch.Size([16]) # dataloader batch size
torch.Size([16, 2048, 1, 1])
torch.Size([16, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([16, 2048, 1, 1])
torch.Size([16]) # Dataloader batch size
torch.Size([16, 2048, 1, 1])
torch.Size([16, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
torch.Size([7, 2048, 1, 1])
...
Why is there a batch size of 7 here? The batch size is defined as 16, and when I print the dataloader's batch size, it is also 16.
Thank you for your time.