Hello, I’m trying to build a Generator for my GAN and this is what it looks like:
class Generator(nn.Module):
def __init__(self, latent_size, init_resolution=8, resolution=256, channels=3):
super(Generator,self).__init__()
#resolution = self.resolution
#channels = self.channels
self.cnn_torch = nn.Sequential()
self.cnn_torch.add_module("Linear1", nn.Linear(latent_size, 1024, bias=False))
self.cnn_torch.add_module("ReLu1", nn.ReLU(inplace=True))
self.cnn_torch.add_module("Linear2", nn.Linear(1024, 128 * init_resolution * init_resolution, bias=False))
self.cnn_torch.add_module("ReLu2", nn.ReLU(inplace=True))
self.cnn_torch.add_module("reshape", Reshape((128, init_resolution, init_resolution)))
crt_res = init_resolution
#upsample
while crt_res != resolution:
self.cnn_torch.add_module("Upsample", nn.UpsamplingNearest2d(scale_factor=2))
if crt_res < resolution/2:
self.cnn_torch.add_module("Upsample Conv2d", nn.Conv2d(crt_res, 256, (5, 5), padding=1, bias=False))
self.cnn_torch.add_module("Upsample ReLU", nn.ReLU(inplace=True))
else:
self.cnn_torch.add_module("Upsample Conv2d - else", nn.Conv2d(crt_res, 128, (5, 5), padding=1, bias=False))
crt_res = crt_res * 2
assert crt_res <= resolution,\
"Error: final resolution [{}] must equal i*2^n. Initial resolution i is [{}]. n must be a natural number.".format(resolution, init_resolution)
self.cnn_torch.add_module("Conv2d final", nn.Conv2d(crt_res, channels, (2,2), padding=1, bias=False))
self.cnn_torch.add_module("Tanh", nn.Tanh())
print(self.cnn_torch)
def forward(self, latent):
fake_image_from_latent = self.cnn_torch(latent)
return fake_image_from_latent
But I’m getting the following error when simply trying to test it (with no images):
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-30-ccb426d28f8b> in <cell line: 14>()
12
13 # Pass the latent vector through the generator
---> 14 fake_images = generator(latent_vector)
15
16 # Inspect the generated fake images
5 frames
<ipython-input-14-8cb072a2d5c1> in forward(self, x)
7
8 def forward(self, x):
----> 9 return x.view(self.shape)
TypeError: view() received an invalid combination of arguments - got (tuple), but expected one of:
* (torch.dtype dtype)
didn't match because some of the arguments have invalid types: (!tuple of (tuple,)!)
* (tuple of ints size)
didn't match because some of the arguments have invalid types: (!tuple of (tuple,)!)
I’m trying to test it this way to see if it outputs the correct tensor shape:
# Instantiate the Generator
latent_size = 100
init_resolution = 8
resolution = 256
channels = 3
generator = Generator(latent_size, init_resolution, resolution, channels)
# Create a random input tensor that represents the latent vector
batch_size = 10
latent_vector = torch.randn(batch_size, latent_size)
# Pass the latent vector through the generator
fake_images = generator(latent_vector)
# Inspect the generated fake images
print(fake_images.shape)
Could anyone help, please? Thanks in advance.