Hi Ptrblck,
I really appreciate your help. My aim is to run a conditional GAN in which the labels are embedded and then concatenated with the real data for the discriminator and with the noise for the generator.
My labels are different float numbers from 0 to 100. I convert each float label to a one-hot vector and then, as you suggested, take the “argmax” to get an index for each label.
The maximum value returned by “argmax” can be 400. My question is: is “nn.Embedding(401, 10)” correct in the current code? Labels44 has size 64 and contains indices ranging from 0 to 400.
class Generator113D_v1(nn.Module):
    """Conditional 3-D generator: embedded class label + noise -> volume.

    The integer label is looked up in an embedding table, reshaped to a
    ``(batch, embed_dim, 1, 1, 1)`` volume and concatenated with the noise
    along the channel axis before four transposed-convolution stages.

    With ``1 x 1 x 1`` noise the spatial size grows 3 -> 7 -> 9 -> 11, so
    the output is ``(batch, 1, 11, 11, 11)`` with values in ``[0, 1]``
    (final ``Sigmoid``).

    Args:
        ngpu: Stored on the instance; not used by this class itself.
        nz: Number of noise channels.
        ngf: Base number of generator feature maps.
        num_classes: Number of distinct label indices. Must be strictly
            greater than the largest index passed to ``forward`` (indices
            0..400 need 401 rows).
        embed_dim: Width of the label embedding.
    """

    def __init__(self, ngpu, nz, ngf, num_classes=401, embed_dim=10):
        super().__init__()
        # One embedding row per label index; created first so the
        # parameter initialisation order matches the previous version.
        self.embedding = nn.Embedding(num_classes, embed_dim)
        self.ngpu = ngpu
        self.nz = nz
        self.ngf = ngf
        self.l1 = nn.Sequential(
            # Input channels = noise channels + embedded-label channels.
            nn.ConvTranspose3d(self.nz + embed_dim, self.ngf * 8, 3, 1, 0, bias=False),
            nn.BatchNorm3d(self.ngf * 8),
            nn.ReLU(True),
        )
        self.l2 = nn.Sequential(
            nn.ConvTranspose3d(self.ngf * 8, self.ngf * 4, 3, 2, 0, bias=False),
            nn.BatchNorm3d(self.ngf * 4),
            nn.ReLU(True),
        )
        self.l3 = nn.Sequential(
            nn.ConvTranspose3d(self.ngf * 4, self.ngf * 2, 3, 1, 0, bias=False),
            nn.BatchNorm3d(self.ngf * 2),
            nn.ReLU(True),
        )
        self.l4 = nn.Sequential(
            nn.ConvTranspose3d(self.ngf * 2, 1, 3, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input, Labels):
        """Generate one volume per (noise, label) pair.

        Args:
            input: Noise tensor of shape ``(batch, nz, 1, 1, 1)``.
            Labels: Integer (long) tensor of shape ``(batch,)`` holding
                label indices in ``[0, num_classes)``.

        Returns:
            Tensor of shape ``(batch, 1, D, H, W)``; ``11 x 11 x 11`` for
            ``1 x 1 x 1`` noise.
        """
        # (batch, embed_dim) -> (batch, embed_dim, 1, 1, 1) so the label
        # can be stacked with the noise on the channel axis.
        label_volume = self.embedding(Labels)[:, :, None, None, None]
        stacked = torch.cat((label_volume, input), 1)
        out = self.l1(stacked)
        out = self.l2(out)
        out = self.l3(out)
        return self.l4(out)
class Discriminator4layer113D(nn.Module):
    """Conditional 3-D discriminator for single-channel volumes.

    The integer label is embedded, projected by a linear layer to one
    value per voxel, reshaped to a one-channel volume and concatenated
    with the input as a second channel before four convolution stages.

    The label projection is a registered submodule (``label_proj``), so
    it is trained by the optimizer, moved by ``.to(device)`` and saved in
    ``state_dict``. Creating the linear layer inside ``forward`` would
    instead produce a fresh, randomly initialised, untrained CPU layer on
    every call.

    NOTE: ``state_dict`` now contains ``label_proj.weight`` and
    ``label_proj.bias``; checkpoints saved without them need
    ``strict=False`` when loading.

    Args:
        ngpu: Stored on the instance; not used by this class itself.
        ndf: Base number of discriminator feature maps.
        num_classes: Number of distinct label indices. Must be strictly
            greater than the largest index passed to ``forward`` (indices
            0..400 need 401 rows).
        embed_dim: Width of the label embedding.
        spatial_size: ``(D, H, W)`` of the input volumes. For the default
            ``11 x 11 x 11`` the output is ``(batch, 1, 1, 1, 1)``.
    """

    def __init__(self, ngpu, ndf, num_classes=401, embed_dim=10,
                 spatial_size=(11, 11, 11)):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embed_dim)
        self.ngpu = ngpu
        self.ndf = ndf
        self.l1 = nn.Sequential(
            # 2 input channels: the volume itself plus the label map.
            nn.Conv3d(2, self.ndf, 3, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.l2 = nn.Sequential(
            nn.Conv3d(self.ndf, self.ndf * 2, 3, 1, 0, bias=False),
            nn.BatchNorm3d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.drop_out2 = nn.Dropout(0.5)
        self.l3 = nn.Sequential(
            nn.Conv3d(self.ndf * 2, self.ndf * 4, 3, 2, 0, bias=False),
            nn.BatchNorm3d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.drop_out3 = nn.Dropout(0.5)
        self.l4 = nn.Sequential(
            nn.Conv3d(self.ndf * 4, 1, 3, 1, 0, bias=False),
            nn.Sigmoid(),
        )
        # Created last so the layers above are initialised exactly as
        # before under a fixed random seed.
        depth, height, width = spatial_size
        self.spatial_size = (int(depth), int(height), int(width))
        self.label_proj = nn.Linear(embed_dim, depth * height * width)

    def forward(self, x, Labels):
        """Score volumes conditioned on their labels.

        Args:
            x: Tensor of shape ``(batch, 1, D, H, W)`` where ``(D, H, W)``
                equals ``spatial_size``.
            Labels: Integer (long) tensor of shape ``(batch,)`` holding
                label indices in ``[0, num_classes)``.

        Returns:
            Tensor of sigmoid scores; ``(batch, 1, 1, 1, 1)`` for
            ``11 x 11 x 11`` inputs.

        Raises:
            ValueError: If the spatial size of ``x`` differs from the
                ``spatial_size`` the label projection was built for.
        """
        if tuple(x.shape[2:]) != self.spatial_size:
            raise ValueError(
                "expected input volumes of spatial size %s, got %s"
                % (self.spatial_size, tuple(x.shape[2:]))
            )
        embedded = self.embedding(Labels)
        # (batch, D*H*W) -> (batch, 1, D, H, W): one extra channel that
        # carries the label information at every voxel.
        label_map = self.label_proj(embedded).view(-1, 1, *self.spatial_size)
        out = self.l1(torch.cat((x, label_map), 1))
        out = self.l2(out)
        out = self.drop_out2(out)
        out = self.l3(out)
        out = self.drop_out3(out)
        return self.l4(out)
def make_one_hot(volumein, max_range=100, discretisation=0.25):
    """One-hot encode float labels on a regular grid.

    Each value ``v`` is mapped to the bin ``round(v / discretisation)``;
    with the defaults that is 401 bins covering 0, 0.25, ..., 100.

    Args:
        volumein: 1-D tensor / array / sequence of float labels.
        max_range: Largest label value that can be encoded.
        discretisation: Bin width. ``1 / discretisation`` should be an
            integer.

    Returns:
        ``torch.float64`` CPU tensor of shape ``(len(volumein), size_hot)``
        where ``size_hot = int(1 / discretisation) * max_range + 1``.
        An empty input yields shape ``(0, size_hot)``.

    Raises:
        ValueError: If any label falls outside ``[0, max_range]`` after
            rounding. Without this check a negative label would silently
            wrap around to the end of the vector.
    """
    # Number of bins per unit of label value (4 for a step of 0.25).
    n_unit = int(1 / discretisation)
    size_hot = n_unit * max_range + 1

    values = torch.as_tensor(volumein).detach().cpu().reshape(-1).to(torch.float64)
    # torch.round rounds halves to even, the same rule np.round uses.
    index = torch.round(values * n_unit).long()

    if index.numel() and (int(index.min()) < 0 or int(index.max()) >= size_hot):
        raise ValueError(
            "labels must lie in [0, %s] to be one-hot encoded" % (max_range,)
        )

    # Build the whole matrix at once instead of concatenating row by row.
    one_hot = torch.zeros(index.numel(), size_hot, dtype=torch.float64)
    one_hot[torch.arange(index.numel()), index] = 1
    return one_hot
# Number of labels drawn for one batch.
batchsize = 64

# Random float labels, uniformly drawn from [1, 100).
Labels11 = (100 - 1) * torch.rand(batchsize) + 1

# One-hot encode every label: one row per label, one column per bin.
OneHoted = make_one_hot(Labels11)
Labels33 = OneHoted

# Collapse each one-hot row back to its bin index, giving one integer
# index per label (shape: batchsize) for use with nn.Embedding.
Labels44 = Labels33.argmax(dim=1)
# Cast both networks to float32 once; doing it inside the loop repeats the
# same no-op every epoch.
netD = netD.float()
netG = netG.float()

# The embedding layers live inside the networks, so the label indices must
# be on the same device as the noise fed to the generator.
# NOTE(review): Labels44 holds `batchsize` entries; this assumes
# real_cpu.size(0) == batchsize -- confirm against the data pipeline.
labels_on_device = Labels44.to(device)

for epoch in range(num_epochs):
    ############################
    # (1) Update D network
    ###########################
    netD.zero_grad()
    b_size = real_cpu.size(0)
    # Explicit float dtype: with an integer fill value torch.full would
    # otherwise produce an integer target.
    label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

    # Real batch, conditioned on its labels.
    output = netD(real_cpu, labels_on_device).view(-1)
    errD_real = criterion(output, label)
    errD_real.backward()

    # Fake batch generated for the same labels.
    noise = torch.randn(b_size, nz, 1, 1, 1, device=device)
    label.fill_(fake_label)
    fake = netG(noise, labels_on_device)
    # detach() keeps this backward pass out of the generator. The labels
    # must be passed here too: the discriminator's forward requires them.
    output = netD(fake.detach(), labels_on_device).view(-1)
    errD_fake = criterion(output, label)
    errD_fake.backward()
    # Update D
    optimizerD.step()

    ############################
    # (2) Update G network
    ###########################
    netG.zero_grad()
    # The generator is trained to make the discriminator answer "real".
    label.fill_(real_label)
    output = netD(fake, labels_on_device).view(-1)
    errG = criterion(output, label)
    errG.backward()
    # Update G
    optimizerG.step()