I am not aware of a current method.
You could do something like the following. The idea is to feed the Generator both your label and the noise, and have it generate an image based on them. Then use the Discriminator not only to distinguish real from fake, but also to classify the label. This way, the real images will backpropagate both the loss from being “anti-fake” and the loss from the classification. Just remember to “negate” the loss when running backpropagation on the generator.
This is just a modification/extension of the DCGAN tutorial, so you’ll need to read that first to get a sense of how to set up the training procedure here. See: DCGAN Tutorial — PyTorch Tutorials 1.13.1+cu117 documentation.
import torch
import torch.nn as nn
class DirectedGAN(nn.Module):
    """Class-conditioned, DCGAN-style generator.

    Two parallel stacks of four transposed-convolution stages are built:
    ``main`` upsamples the noise tensor and ``direct`` upsamples the class
    tensor. After every stage the two feature maps are summed, and the final
    sum is squashed with ``tanh``.

    Args:
        channels: Channel count of both the noise and the class input.
        hidden: Base feature width; the intermediate stages produce
            ``hidden * 8``, ``hidden * 4`` and ``hidden * 2`` feature maps.
        image_channels: Channel count of the generated image (3 by default).
    """

    _NUM_STAGES = 4

    def __init__(self, channels, hidden, image_channels=3):
        super().__init__()
        self.main = nn.ModuleList()
        self.direct = nn.ModuleList()
        for i in reversed(range(self._NUM_STAGES)):
            is_first = i == self._NUM_STAGES - 1
            is_last = i == 0
            # The first stage expands the 1x1 input to 4x4 (stride 1, no
            # padding); every later stage doubles the spatial size.
            stride, padding = (1, 0) if is_first else (2, 1)
            in_channels = channels if is_first else hidden * 2 ** (i + 1)
            out_channels = image_channels if is_last else hidden * 2 ** i
            self.main.append(
                self._make_stage(in_channels, out_channels, stride, padding, is_last)
            )
            self.direct.append(
                self._make_stage(in_channels, out_channels, stride, padding, is_last)
            )
        self.tanh = nn.Tanh()

    @staticmethod
    def _make_stage(in_channels, out_channels, stride, padding, is_last):
        """Build one upsampling stage: ConvTranspose2d (+ BatchNorm + ReLU)."""
        layers = [
            nn.ConvTranspose2d(
                in_channels, out_channels, 4, stride, padding, bias=False
            )
        ]
        # The output stage must stay linear: a ReLU in front of tanh would
        # confine every generated pixel to [0, 1) instead of (-1, 1).
        if not is_last:
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))
        return nn.Sequential(*layers)

    def forward(self, noise, dir_class):
        """Generate a batch of images from noise and a class tensor.

        Args:
            noise: Random tensor of size ``(batch, channels, 1, 1)``.
            dir_class: Tensor of the same size as ``noise``, filled with the
                class number (e.g. via ``torch.full``); you may want to
                normalise it as ``class_num / total_classes``.

        Returns:
            Tensor of size ``(batch, image_channels, 32, 32)`` with values in
            ``(-1, 1)`` when the inputs are 1x1 spatially.
        """
        for main_stage, direct_stage in zip(self.main, self.direct):
            noise = main_stage(noise)
            dir_class = direct_stage(dir_class)
            # Fuse the class branch into the noise branch after every stage.
            noise = noise + dir_class
        return self.tanh(noise)
class Director(nn.Module):
    """DCGAN-style discriminator with an additional classification head.

    Three stride-2 convolution stages reduce a 32x32 image to a 4x4 feature
    map, which feeds two 4x4 convolution heads: one scoring each class and
    one scoring real vs. fake.

    Args:
        channels: Unused; kept so existing call sites keep working.
        hidden: Base feature width; the stages produce ``hidden * 2``,
            ``hidden * 4`` and ``hidden * 8`` feature maps.
        num_classes: Number of classes scored by the classification head.
        image_channels: Channel count of the input image (3 by default).
    """

    def __init__(self, channels, hidden, num_classes, image_channels=3):
        super().__init__()
        self.main = nn.ModuleList()
        in_channels = image_channels
        for i in range(3):
            out_channels = hidden * 2 ** (i + 1)
            self.main.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
            in_channels = out_channels
        # No BatchNorm on the heads: normalising a 1x1 output across the
        # batch pins the batch mean of the score (so an all-real and an
        # all-fake batch look the same in training mode) and raises for a
        # batch of one sample.
        self.finout_class = nn.Conv2d(in_channels, num_classes, 4, 1, 0)
        self.finout_descriminator = nn.Conv2d(in_channels, 1, 4, 1, 0)
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Image tensor of size ``(batch, image_channels, 32, 32)``.

        Returns:
            A ``(class_logits, real_prob)`` tuple. ``class_logits`` has size
            ``(batch, num_classes, 1, 1)`` and holds raw, unnormalised
            scores, as expected by ``nn.CrossEntropyLoss``. ``real_prob`` has
            size ``(batch, 1, 1, 1)`` and holds sigmoid probabilities, as
            expected by ``nn.BCELoss``.
        """
        for stage in self.main:
            x = stage(x)
        class_logits = self.finout_class(x)
        real_prob = self.sigm(self.finout_descriminator(x))
        return class_logits, real_prob
# Hyper-parameters for the demo below.
channels = 100
hidden = 64
num_classes = 10
batch_size = 32
target_class = 3

gan = DirectedGAN(channels, hidden)
# Named `director` rather than `dir` so the builtin `dir()` is not shadowed.
director = Director(channels, hidden, num_classes)

# example
noise = torch.rand(batch_size, channels, 1, 1)
# The class input has the same shape as the noise and is filled with the
# normalised class number.
dir_data = torch.full((batch_size, channels, 1, 1), target_class / num_classes)
x = gan(noise, dir_data)
classes, realfake = director(x)
print(classes.size(), realfake.size())

# get loss
classcrit = nn.CrossEntropyLoss()
realfakecrit = nn.BCELoss()
# CrossEntropyLoss needs integer (long) class indices; make the dtype explicit.
class_targs = torch.full((batch_size,), target_class, dtype=torch.long)
realfake_targs = torch.ones(batch_size)
class_loss = classcrit(classes.view(batch_size, num_classes), class_targs)
realfake_loss = realfakecrit(realfake.view(-1), realfake_targs)
loss = class_loss + realfake_loss
print(loss)