Passing image as a condition to Discriminator to make it conditional

Hi team,

How can I pass an image to a Discriminator block in order to condition it?

Below are my inputs:

Main image - torch.randn(1,3,256,256)
Conditional image - torch.randn(1,3,256,256)
Target_colors = torch.zeros(1,1,3,dtype=torch.float)
Target_colors[:,:,0] = 255
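
From what I understand, conditioning a discriminator on an image usually means stacking the conditional image onto the main image along the channel dimension. A minimal sketch of what I mean (variable names here are just placeholders):

import torch

main_img = torch.randn(1, 3, 256, 256)   # main image
cond_img = torch.randn(1, 3, 256, 256)   # conditional image

# stack along the channel dimension, so the discriminator's first conv sees 6 input channels
disc_input = torch.cat((main_img, cond_img), dim=1)
print(disc_input.shape)                  # torch.Size([1, 6, 256, 256])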

Below is the way I am trying:

self.label_condition_color = nn.Sequential(nn.Embedding(300,3),nn.Linear(3,256*256),
nn.ReLU(True))

self.embed_sketch = nn.Embedding(1,self.image_size*self.image_size)

I need to concatenate both of these with the main image above:

concat = torch.cat((x, label_output), dim=1)
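
Tracing the shapes through that path (assuming the label is passed as a LongTensor of three indices, since nn.Embedding needs integer inputs), I get:

import torch
import torch.nn as nn

label_condition_color = nn.Sequential(nn.Embedding(300, 3),
                                      nn.Linear(3, 256 * 256),
                                      nn.ReLU(True))

label = torch.tensor([[255, 0, 0]])                  # (1, 3) LongTensor of indices (hypothetical)
label_output = label_condition_color(label)          # (1, 3, 65536)
label_output = label_output.view(-1, 3, 256, 256)    # (1, 3, 256, 256)

x = torch.randn(1, 3, 256, 256)                      # main image
concat = torch.cat((x, label_output), dim=1)         # (1, 6, 256, 256)
print(concat.shape)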

Actual code I am using:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    def __init__(self, image_size):
        super(Discriminator, self).__init__()
        self.image_size = image_size

        # 6 input channels: 3 for the main image + 3 for the conditioning map
        self.main = nn.Sequential(
            nn.Conv2d(6, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 4, padding=1),
            nn.Sigmoid()
        )
        self.label_condition_color = nn.Sequential(nn.Embedding(300, 3),
                                                   nn.Linear(3, 256 * 256),
                                                   nn.ReLU(True))
        # self.label_condition_disc = nn.Sequential(nn.Linear(3, 3 * 256 * 256),
        #                                           nn.ReLU(True))
        self.embed_sketch = nn.Embedding(1, self.image_size * self.image_size)

    def forward(self, inputs):
        x, sketch, label = inputs
        label_output = self.label_condition_color(label)
        label_output = label_output.view(-1, 3, self.image_size, self.image_size)
        print(sketch.shape)
        # sketch_embed = self.embed_sketch(sketch).view(sketch.shape[0], 1, self.image_size, self.image_size)
        # print("sketch embed shape is:", sketch_embed.shape)
        concat = torch.cat((x, label_output), dim=1)
        x = self.main(concat)
        # global average pool over the patch map, then flatten to (batch, 1)
        x = F.avg_pool2d(x, x.size()[2:])
        x = torch.flatten(x, 1)
        return x
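
For reference, this is roughly how I am calling it (a minimal sketch; the label here is a hypothetical LongTensor of indices, since nn.Embedding needs integer inputs rather than the float Target_colors above):

import torch

disc = Discriminator(image_size=256)

main_img = torch.randn(1, 3, 256, 256)   # main image
cond_img = torch.randn(1, 3, 256, 256)   # conditional / sketch image
label = torch.tensor([[255, 0, 0]])      # hypothetical indices for nn.Embedding(300, 3)

out = disc((main_img, cond_img, label))
print(out.shape)                         # torch.Size([1, 1])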