# Constant colour target of shape (1, 1, 3), float32, created directly on
# DEVICE: channels 0 and 1 are 0 and channel 2 is 255 (blue, if the last axis
# is RGB-ordered — confirm against the image pipeline).
Target_color = torch.tensor([[[0.0, 0.0, 255.0]]], dtype=torch.float, device=DEVICE)
# Below is my model:
class Discriminator(nn.Module):
    """Class-conditional convolutional discriminator.

    The integer class label is embedded into a 3-channel map with the same
    spatial size as the image. That map is concatenated with the image along
    the channel axis (3 + 3 = 6 channels, matching the first ``Conv2d``) and
    scored by ``self.main``.

    Args:
        num_classes: Number of distinct label values accepted by ``forward``.
        image_size: Height and width of the square input images. The conv
            stack halves the size three times and then shrinks it by 1 twice.
            For a multiple of 8 the final map is ``image_size // 8 - 2`` pixels
            per side, so images smaller than about 24 px collapse to nothing.
    """

    # Channels the label embedding is reshaped into; with the 3 image channels
    # this gives the 6 input channels expected by the first convolution.
    _LABEL_CHANNELS = 3

    def __init__(self, num_classes, image_size):
        super().__init__()
        self.image_size = image_size
        self.main = nn.Sequential(
            # Three stride-2 stages (k=4, p=1): each halves the spatial size.
            nn.Conv2d(6, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # Two stride-1 stages (k=4, p=1): each shrinks the size by 1.
            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1),
            nn.Sigmoid(),
        )
        # One full 3 x H x W label map per class. The previous H*W width could
        # not be viewed as (batch, 3, H, W) without corrupting the batch axis.
        self.embed = nn.Embedding(
            num_classes, self._LABEL_CHANNELS * image_size * image_size
        )

    def forward(self, inputs):
        """Score a batch of (image, label) pairs.

        Args:
            inputs: Tuple ``(x, label)``. ``x`` is expected to be
                ``(batch, 3, image_size, image_size)``. ``label`` holds integer
                class indices of shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, 1)``: the sigmoid output map averaged
            over its spatial positions, so every value lies in ``[0, 1]``.
        """
        x, label = inputs
        # (batch, 3*H*W) -> (batch, 3, H, W). Using x's batch size makes an
        # image/label count mismatch fail here with a clear shape error.
        label_map = self.embed(label).view(
            x.size(0), self._LABEL_CHANNELS, self.image_size, self.image_size
        )
        out = self.main(torch.cat((x, label_map), dim=1))
        # Average over the remaining spatial map -> one score per sample.
        out = F.avg_pool2d(out, out.size()[2:])
        return torch.flatten(out, 1)