What methods can constrain a GAN to generate images of a specific class?

I’m trying to use StarGAN v2 with SPADE blocks to generate images that keep the original content while controlling the features at a given location, by feeding semantic information into the SPADE blocks, e.g., generating a type 1 defect on a door.

In some cases, more than one defect type is present, e.g., type 1 and type 2 both appear on the door.

My question is: how can I constrain the model to generate a specific type of defect in the image?

I want to add a classifier (not the discriminator) with a cross-entropy loss to guide the generator.

But I don’t know how to handle the case where several defect types appear together.

I would also like to know whether the class loss is supposed to update the generator or the classifier.

I am not aware of a current method.

You could do something like the following. The idea is to feed the Generator both your label and the noise, and have it generate an image based on them. Then use the Discriminator not only to tell real from fake, but also to classify the label. This way, the real images will backpropagate both the loss from being “anti-fake” and the loss from the classification. Just keep in mind to “negate” the loss when running backpropagation on the generator.

This is just a modification/extension of the DCGAN tutorial, so you’ll need to read that to get a sense of how to set up the training method here: DCGAN Tutorial (PyTorch Tutorials 1.13.1+cu117 documentation).
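To make the training setup more concrete, here is a minimal sketch of one discriminator step and one generator step with an extra class loss. It is not taken from the tutorial: `TinyG`, `TinyD`, the 8x8 image size, and the random "real" batch are hypothetical stand-ins so the block runs on its own. It assumes one common reading of "negating" the loss, namely the DCGAN tutorial trick of labelling the fakes as real in the generator step. In this sketch the class loss on real images updates the discriminator/classifier, and the class loss on generated images updates the generator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes, latent_dim, batch = 10, 16, 8

# Hypothetical tiny stand-ins for the generator and the two-headed discriminator.
class TinyG(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(num_classes, latent_dim)
        self.net = nn.Sequential(nn.Linear(latent_dim, 3 * 8 * 8), nn.Tanh())

    def forward(self, z, labels):
        # Condition on the label by adding its embedding to the noise.
        return self.net(z + self.embed(labels)).view(-1, 3, 8, 8)

class TinyD(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.LeakyReLU(0.2))
        self.adv_head = nn.Linear(32, 1)            # real/fake logit
        self.cls_head = nn.Linear(32, num_classes)  # class logits

    def forward(self, x):
        h = self.body(x)
        return self.cls_head(h), self.adv_head(h).squeeze(1)

G, D = TinyG(), TinyD()
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
adv_crit, cls_crit = nn.BCEWithLogitsLoss(), nn.CrossEntropyLoss()

# Placeholder "real" batch; in practice this comes from a labelled data loader.
real_imgs = torch.rand(batch, 3, 8, 8) * 2 - 1
real_labels = torch.randint(0, num_classes, (batch,))

# Generate a batch of fakes for randomly requested classes.
z = torch.randn(batch, latent_dim)
fake_labels = torch.randint(0, num_classes, (batch,))
fake_imgs = G(z, fake_labels)

# Discriminator step: real -> 1, fake -> 0, plus class loss on the real images.
opt_d.zero_grad()
cls_real, adv_real = D(real_imgs)
_, adv_fake = D(fake_imgs.detach())  # detach: no gradient flows into G here
loss_d = (adv_crit(adv_real, torch.ones(batch))
          + adv_crit(adv_fake, torch.zeros(batch))
          + cls_crit(cls_real, real_labels))
loss_d.backward()
opt_d.step()

# Generator step: fakes are labelled as real, and the class loss is NOT flipped,
# so G is pushed to produce images that D classifies as the requested class.
opt_g.zero_grad()
cls_fake, adv_fake = D(fake_imgs)
loss_g = adv_crit(adv_fake, torch.ones(batch)) + cls_crit(cls_fake, fake_labels)
loss_g.backward()
opt_g.step()
```

Only `opt_g.step()` is called after the generator loss, so the discriminator weights are untouched in that step even though gradients pass through it.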

import torch
import torch.nn as nn

class DirectedGAN(nn.Module):
    def __init__(self, channels, hidden):
        super(DirectedGAN, self).__init__()
        self.main=nn.ModuleList()
        self.direct=nn.ModuleList()
        for i in reversed(range(4)):
            if i==3:
                stride=1
                padding=0
                in_channels=channels
            else:
                stride=2
                padding=1
                in_channels=hidden*2**(i+1)

            if i==0:
                out_channels=3
            else:
                out_channels=hidden*2**i
            # Hidden layers use BatchNorm + ReLU; the last layer (i == 0) stays linear
            # so that Tanh can produce values in the full [-1, 1] range.
            for branch in (self.main, self.direct):
                layers=[nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)]
                if i!=0:
                    layers+=[nn.BatchNorm2d(out_channels), nn.ReLU(True)]
                branch.append(nn.Sequential(*layers))

        self.tanh=nn.Tanh()

    def forward(self, noise, dir_class):
        # noise: random tensor of size (batch, channels, 1, 1)
        # dir_class: same size, filled with the class number via torch.full;
        #            you may want to normalize it as class_num/total_classes

        for i in range(4):
            noise=self.main[i](noise)
            dir_class=self.direct[i](dir_class)
            noise=noise+dir_class

        return self.tanh(noise)

class Director(nn.Module):
    def __init__(self, channels, hidden, num_classes):
        super(Director, self).__init__()
        self.main=nn.ModuleList()

        for i in range(3):
            stride=2
            padding=1
            in_channels=hidden*2**i
            out_channels = hidden * 2 ** (i+1)
            if i==0:
                in_channels=3

            self.main.append(nn.Sequential(nn.Conv2d(in_channels, out_channels, 4, stride, padding, bias=False),
                                    nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2, inplace=True)))

        # Output heads: no BatchNorm here, since normalizing the final scores
        # across the batch would distort the class and real/fake outputs.
        self.finout_class=nn.Conv2d(out_channels, num_classes, 4, 1, 0, bias=False)
        self.finout_descriminator=nn.Conv2d(out_channels, 1, 4, 1, 0, bias=False)

        self.sigm=nn.Sigmoid()

    def forward(self, x):
        for i in range(3):
            x=self.main[i](x)
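        y=self.finout_descriminator(x)
        x = self.finout_class(x)
        # Class scores are returned as raw logits because nn.CrossEntropyLoss
        # applies log-softmax itself; only the real/fake score goes through the sigmoid.
        return x, self.sigm(y)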

channels=100
hidden=64
num_classes=10
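gan=DirectedGAN(channels, hidden)
director=Director(channels, hidden, num_classes)  # named "director" to avoid shadowing the built-in dir()

# example: request class 3 for every sample in a batch of 32
noise=torch.rand(32, channels, 1, 1)
dir_data=torch.full((32,channels, 1, 1), 3/num_classes)

x=gan(noise, dir_data)  # generated images of size (32, 3, 32, 32)
classes, realfake=director(x)
print(classes.size(), realfake.size())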

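# get loss
classcrit=nn.CrossEntropyLoss()  # for the class head
realfakecrit=nn.BCELoss()        # for the real/fake head (expects probabilities)

class_targs=torch.full((32,), 3)  # every sample was requested as class 3
realfake_targs=torch.ones(32)     # label 1 means "real"
loss=classcrit(classes.view(32,num_classes), class_targs)+realfakecrit(realfake.view(-1), realfake_targs)
print(loss)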
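The example above assumes exactly one class per image. For the case where several defect types appear together, one possible approach (an assumption on my part, not something I have tested on StarGAN v2) is to treat the class head as a multi-label classifier: use multi-hot targets and `nn.BCEWithLogitsLoss` instead of `nn.CrossEntropyLoss`, so each defect type gets its own independent yes/no decision. The sketch below uses random logits as a stand-in for the class head output, and `num_types` is a hypothetical value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_types = 3  # hypothetical number of defect types
batch = 4

# Multi-hot targets: row i has a 1 for every defect type present in image i.
targets = torch.tensor([[1., 0., 0.],   # type 1 only
                        [1., 1., 0.],   # type 1 and type 2 together
                        [0., 0., 1.],   # type 3 only
                        [0., 0., 0.]])  # no defect

# Stand-in for the class head output: raw logits, one per defect type.
logits = torch.randn(batch, num_types, requires_grad=True)

# BCEWithLogitsLoss applies a sigmoid per type, so types do not compete
# with each other the way they do under softmax/cross-entropy.
multi_label_crit = nn.BCEWithLogitsLoss()
loss = multi_label_crit(logits, targets)
loss.backward()

# At inference time, each type is thresholded independently.
predicted = (torch.sigmoid(logits) > 0.5).float()
print(loss.item(), predicted)
```

The same multi-hot vector could also be what you feed to the generator as the conditioning input, in place of the single class number used in `dir_data`.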