I changed it in this way. The output of the embedding is 64x10, which is correct since my batch size is 64 and the embedding maps each label to a 10-dimensional vector.
class Discriminator(nn.Module):
    """Conditional 3D GAN discriminator.

    Scores a 1-channel 11x11x11 volume as real/fake, conditioned on an
    integer class label. The label is embedded, linearly projected to
    11**3 = 1331 values, reshaped into an 11x11x11 volume, and concatenated
    to the input as a second channel before the convolutional stack.

    Spatial sizes through the stack (kernel 3, no padding):
    11 -> 9 (l1) -> 7 (l2) -> 3 (l3, stride 2) -> 1 (l4).

    Args:
        ngpu: number of GPUs (stored for use by the training loop; not
            used inside this module itself).
        ndf: base number of discriminator feature maps.
        num_classes: size of the label vocabulary. Defaults to 401, the
            value previously hard-coded (the old comment claimed 64, which
            was actually the batch size, not the label count).
        embed_dim: dimensionality of the label embedding. Defaults to 10.
    """

    def __init__(self, ngpu, ndf, num_classes=401, embed_dim=10):
        super(Discriminator, self).__init__()
        # Map each of `num_classes` integer labels to an `embed_dim` vector.
        self.embedding = nn.Embedding(num_classes, embed_dim)
        self.ngpu = ngpu
        self.ndf = ndf
        # Project the embedding to 11**3 values so it can be reshaped into
        # an 11x11x11 volume matching the input's spatial size.
        self.l = nn.Linear(embed_dim, 1331)
        # Input has 2 channels: the data volume plus the label volume.
        self.l1 = nn.Sequential(
            nn.Conv3d(2, self.ndf, 3, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.l2 = nn.Sequential(
            nn.Conv3d(self.ndf, self.ndf * 2, 3, 1, 0, bias=False),
            nn.BatchNorm3d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.drop_out2 = nn.Dropout(0.5)
        self.l3 = nn.Sequential(
            nn.Conv3d(self.ndf * 2, self.ndf * 4, 3, 2, 0, bias=False),
            nn.BatchNorm3d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.drop_out3 = nn.Dropout(0.5)
        # Final 3x3x3 conv collapses the 3x3x3 map to a single logit,
        # squashed to (0, 1) by Sigmoid.
        self.l4 = nn.Sequential(
            nn.Conv3d(self.ndf * 4, 1, 3, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x, Labels):
        """Return a (B, 1, 1, 1, 1) real/fake score for each sample.

        Args:
            x: input volumes of shape (B, 1, 11, 11, 11).
            Labels: integer class labels; any shape that flattens to (B,)
                (e.g. the original (B, 1, 1, 1)) is accepted.
        """
        # Flatten labels to (B,) — equivalent to the original chain of
        # three squeeze(1) calls on a (B, 1, 1, 1) tensor, but also
        # tolerant of labels already shaped (B,).
        Labels = Labels.reshape(-1)
        Out1 = self.embedding(Labels)          # (B, embed_dim)
        Out2 = self.l(Out1)                    # (B, 1331)
        # Reshape the projected label to the input's spatial size and give
        # it a channel axis for concatenation.
        Out3 = Out2.view(-1, 11, 11, 11).unsqueeze(1)  # (B, 1, 11, 11, 11)
        # Concatenate label volume and input along the channel axis.
        Out4 = torch.cat((x, Out3), 1)         # (B, 2, 11, 11, 11)
        out = self.l1(Out4)
        out = self.l2(out)
        out = self.drop_out2(out)
        out = self.l3(out)
        out = self.drop_out3(out)
        out = self.l4(out)
        return out
class Generator(nn.Module):
    """Conditional 3D GAN generator.

    Produces a 1-channel 9x9x9 volume from an nz-dimensional noise vector
    and an integer class label. The label is embedded and concatenated to
    the noise along the channel axis before the transposed-convolution
    stack, and the Sigmoid output is scaled by `Sigmad`.

    Spatial sizes through the stack (kernel 3, stride 1, no padding):
    1 -> 3 (l1) -> 5 (l2) -> 7 (l3) -> 9 (l4).

    Args:
        ngpu: number of GPUs (stored for use by the training loop; not
            used inside this module itself).
        nz: dimensionality of the input noise vector.
        ngf: base number of generator feature maps.
        num_classes: size of the label vocabulary. Defaults to 401, the
            value previously hard-coded.
        embed_dim: dimensionality of the label embedding. Defaults to 10.
    """

    def __init__(self, ngpu, nz, ngf, num_classes=401, embed_dim=10):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.nz = nz
        self.ngf = ngf
        # Map each of `num_classes` integer labels to an `embed_dim` vector.
        self.embedding = nn.Embedding(num_classes, embed_dim)
        # First layer consumes noise + label channels together.
        self.l1 = nn.Sequential(
            nn.ConvTranspose3d(self.nz + embed_dim, self.ngf * 8, 3, 1, 0, bias=False),
            nn.BatchNorm3d(self.ngf * 8),
            nn.ReLU(True),
        )
        self.l2 = nn.Sequential(
            nn.ConvTranspose3d(self.ngf * 8, self.ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm3d(self.ngf * 4),
            nn.ReLU(True),
        )
        self.l3 = nn.Sequential(
            nn.ConvTranspose3d(self.ngf * 4, self.ngf * 2, 3, 1, 0, bias=False),
            nn.BatchNorm3d(self.ngf * 2),
            nn.ReLU(True),
        )
        # Sigmoid keeps raw outputs in (0, 1); forward() rescales by Sigmad.
        self.l4 = nn.Sequential(
            nn.ConvTranspose3d(self.ngf * 2, 1, 3, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input, Labels, Sigmad):
        """Return generated volumes of shape (B, 1, 9, 9, 9), scaled by Sigmad.

        Args:
            input: noise of shape (B, nz, 1, 1, 1).
            Labels: integer class labels; any shape that flattens to (B,)
                (e.g. the original (B, 1, 1, 1)) is accepted.
            Sigmad: scalar (or broadcastable tensor) multiplier applied to
                the Sigmoid output.
        """
        # Flatten labels to (B,) — equivalent to the original chain of
        # three squeeze(1) calls on a (B, 1, 1, 1) tensor, but also
        # tolerant of labels already shaped (B,).
        Labels = Labels.reshape(-1)
        Out1 = self.embedding(Labels)                       # (B, embed_dim)
        # Give the embedding singleton spatial dims so it concatenates with
        # the noise along the channel axis.
        Out1 = Out1.unsqueeze(2).unsqueeze(3).unsqueeze(4)  # (B, embed_dim, 1, 1, 1)
        Out2 = torch.cat((Out1, input), 1)                  # (B, embed_dim + nz, 1, 1, 1)
        out = self.l1(Out2)
        out = self.l2(out)
        out = self.l3(out)
        out = self.l4(out) * Sigmad
        return out