How do I pass the tensor below to an embedding layer?

# Target colour as a (1, 1, 3) float tensor: first two components 0, last
# component 255, moved to the configured device.
Target_color = torch.tensor([[[0.0, 0.0, 255.0]]], dtype=torch.float).to(DEVICE)

Below is my model:

class Discriminator(nn.Module):
    """Convolutional discriminator conditioned on an embedded label.

    The label indices are looked up in an embedding table whose rows have
    ``image_size * image_size`` entries, reshaped into three image-sized
    planes, and concatenated with the input image along the channel axis.
    The resulting 6-channel tensor is scored by a strided conv stack ending
    in a sigmoid, and the score map is averaged to one value per sample.

    Args:
        num_classes: Number of rows in the embedding table. Every label
            index must be smaller than this value.
        image_size: Height and width of the (square) input images.
    """

    def __init__(self, num_classes, image_size):
        super().__init__()
        self.image_size = image_size

        self.main = nn.Sequential(
            # 6 input channels: the image concatenated with the label planes.
            nn.Conv2d(6, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 4, padding=1),
            nn.Sigmoid(),
        )
        # Alternative label-conditioning path (disabled):
        # self.label_condition_disc = nn.Sequential(nn.Embedding(1, 3),nn.Linear(3, 3*256*256),
        #                             nn.ReLU(True))

        # Each index maps to one flattened image_size x image_size plane.
        self.embed = nn.Embedding(num_classes, self.image_size * self.image_size)

    def forward(self, inputs):
        """Score a batch of images given their labels.

        Args:
            inputs: A pair ``(x, label)``. ``x`` must have 3 channels and
                spatial size ``image_size`` so that, once concatenated with
                the label planes, it matches the 6-channel first convolution.
                ``label`` holds embedding indices in ``[0, num_classes)``;
                it is assumed to carry three indices per sample (e.g. shape
                ``(N, 3)``) so the reshape yields one 3-plane map per
                sample -- TODO confirm against the caller.

        Returns:
            Tensor of shape ``(N, 1)``: the sigmoid score map averaged over
            its spatial dimensions.
        """
        x, label = inputs
        # nn.Embedding rejects floating-point indices; cast so a float tensor
        # holding whole-number values can be passed directly. Fractional
        # values are truncated toward zero by the cast.
        label_output = self.embed(label.long())
        label_output = label_output.view(-1, 3, self.image_size, self.image_size)
        concat = torch.cat((x, label_output), dim=1)
        x = self.main(concat)
        # Global average pool: kernel size equals the full spatial extent.
        x = F.avg_pool2d(x, x.size()[2:])

        x = torch.flatten(x, 1)
        return x
# Example table: 10 possible indices, each mapped to a 100-dimensional vector.
embed = torch.nn.Embedding(10, 100)

# torch.Tensor(...) builds a float32 tensor. nn.Embedding only accepts integer
# indices, so this lookup is rejected; the error is caught and reported so the
# rest of the snippet still runs.
try:
    embed(torch.Tensor([0, 1, 1, 4]))
except RuntimeError as err:
    print(f"float indices are rejected: {err}")

# Build the indices as int64 (long) directly and the lookup succeeds,
# returning one embedding row per index.
indices = torch.tensor([0, 1, 1, 4], dtype=torch.long)
vectors = embed(indices)

Hi anant,

I have to pass torch.tensor([255,0,0]) to the embedding layer …

Also, how did you decide to use an embedding of size (10, 100)?

@KURUVILLA_ABRAHAM the (10, 100) is just an example for illustration.

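In order to pass torch.tensor([255,0,0]) to the embedding, create it as torch.LongTensor([255, 0, 0]). The indices must have an integer dtype; a tensor built with integer literals via torch.tensor(...) is already int64, so the explicit conversion is only needed when the tensor was created as float.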

OK, can you also explain what embedding size we should use for this tensor, and why?

I am passing it this way:

# 300-row table of 3-dimensional vectors; indices 0..299 are valid.
embed = torch.nn.Embedding(num_embeddings=300, embedding_dim=3)
# Cast the float tensor to long before the lookup: one 3-vector per index.
embed(torch.Tensor([255, 0, 0]).long())

Is this the correct way?

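Yes, this is the correct way :+1: The first argument (the number of embeddings) must be larger than the largest index you look up — 255 here, so any value of at least 256 works — and the second argument is the length of the vector returned for each index.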