class Discriminator(nn.Module):
def init(self):
super().init()
self.model=nn.Sequential(
nn.Linear(392*384,1024),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(1024,512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512,256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(256,1),
nn.Sigmoid()
)
def forward(self,x):
x=x.view(x.size(0),672)
output=self.model(x)
return output
#generator
class Generator(nn.Module):
def init(self):
super().init()
self.model=nn.Sequential(
nn.Linear(1000,224),
nn.ReLU(),
nn.Linear(224,448),
nn.ReLU(),
nn.Linear(448,672),
nn.Tanh())
def forward(self,x):
x=x.view(x.size(0),224)
output=self.model(x)
return output
loss_function=nn.BCELoss()
optimizer_discriminator=torch.optim.Adam(discriminator.parameters())
optimizer_generator=torch.optim.Adam(generator.parameters())
discriminator=Discriminator()
generator=Generator()
#Training the model
batch_size=1000
num_epochs=10
for epoch in range(num_epochs):
for n,(real_samples,Labels) in enumerate(train_set):
real_samples=real_samples
real_sample_labels=torch.ones((batch_size,1))
latent_heat_samples=torch.randn((batch_size,224))
generated_samples=generator(latent_heat_samples)
generated_sample_label=torch.zeros((batch_size,1))
all_samples=torch.cat((real_samples,generated_samples))
all_sample_labels=torch.cat((real_sample_labels,generated_sample_label))
#training discriminator
optimizer_discriminator.zero_grad()
discriminator_samples=discriminator(all_samples)
loss_discriminator=loss_function(discriminator_samples,all_sample_labels)
loss_discriminator.backward()
optimizer_discriminator.step()
#training generator
optimizer_generator.zero_grad()
generator_samples=generator(latent_heat_samples)
generator_discriminator_sample=discriminator(generator_samples)
loss_generator=loss_function(generator_discriminator_samples,real_samples)
loss_generator.backward()
optimizer_generator.step()
#printing losses at each epoch
if n==batch_size-1:
print(f"Epoch:{epoch}, Loss D:{loss_discriminator}")
print(f"Epoch:{epoch}, Loss G:{loss_generator}")
Error as
RuntimeError Traceback (most recent call last)
in
10 latent_heat_samples=torch.randn((batch_size,224))
11
—> 12 generated_samples=generator(latent_heat_samples)
13 generated_sample_label=torch.zeros((batch_size,1))
14
~\anaconda3\envs\gan\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
–> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
in forward(self, x)
13 def forward(self,x):
14 x=x.view(x.size(0),224)
—> 15 output=self.model(x)
16 return output
17
~\anaconda3\envs\gan\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
–> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
~\anaconda3\envs\gan\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
98 def forward(self, input):
99 for module in self:
–> 100 input = module(input)
101 return input
102
~\anaconda3\envs\gan\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
–> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
~\anaconda3\envs\gan\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
85
86 def forward(self, input):
—> 87 return F.linear(input, self.weight, self.bias)
88
89 def extra_repr(self):
~\anaconda3\envs\gan\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
1368 if input.dim() == 2 and bias is not None:
1369 # fused op is marginally faster
-> 1370 ret = torch.addmm(bias, input, weight.t())
1371 else:
1372 output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [1000 x 224], m2: [1000 x 224] at C:\Users\builder\AppData\Local\Temp\pip-req-build-e5c8dddg\aten\src\TH/generic/THTensorMath.cpp:136
Please suggest me to solve this error, I am new to Pytorch. Thanks in Advance.