Sure thing, here is the regularizer module that is being called:
import torch
import torch.nn as nn
from torch.autograd import grad


class DiscGradReg(nn.Module):
    """Discriminator gradient regularization for inversion (an R1-style gradient penalty)."""

    def __init__(self, gamma: float = 10.):
        """
        Args:
            gamma: Weight on the gradient penalty term (default: 10.)
        """
        super().__init__()
        self.gamma = gamma

    def forward(self, inpt, output):
        # Penalty on the discriminator's gradient w.r.t. its (real) inputs:
        # (gamma / 2) * ||grad_x D(x)||^2. create_graph=True keeps the penalty
        # differentiable so it can be backpropagated during the D update.
        grads = grad(output.mean(), inpt, create_graph=True)[0]
        return (self.gamma / 2.0) * (torch.norm(grads) ** 2)
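To make sure I'm reading my own regularizer right: it should compute (gamma / 2) * ||grad_x D(x)||^2 on the real images. A minimal standalone check would look roughly like this; the toy discriminator and the shapes are placeholders for illustration only, not the real D:

# Standalone sanity check for DiscGradReg (toy discriminator and shapes are made up).
toy_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))
x = torch.randn(4, 3, 64, 64, requires_grad=True)  # inputs must require grad
penalty = DiscGradReg(gamma=10.)(x, toy_d(x))
penalty.backward()  # only possible because forward uses create_graph=True
print(penalty.item())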
and this is the main training loop
def search(**kwargs):
    device = torch.device('cuda')
    opts = dnnlib.EasyDict(kwargs)  # Command line arguments.
    # Determine the number of classes (one sub-folder per class)
    num_classes = len(os.listdir(opts['data']))
    # Define transforms
    tsfms = transforms.Compose([
        transforms.Resize(256, interpolation=1),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Calculate the number of steps over which to accumulate gradients
    accumulate_steps = (opts['batch'] // opts['batch_gpu']) if opts.get('batch_gpu') else 1
    dataset = ImageFolder(
        opts['data'],
        transform=tsfms,
        target_transform=lambda x: F.one_hot(torch.tensor(x, dtype=torch.int64), num_classes).float(),
    )
    dataloader = DataLoader(
        dataset,
        shuffle=True,
        batch_size=opts['batch_gpu'] if opts.get('batch_gpu') else opts['batch'],
        drop_last=True,
        num_workers=opts['workers'],
    )
    encoder_loss = EncoderLoss()
    discriminator_grad_reg = DiscGradReg()
    # Load the generator in evaluation mode and the discriminator in training mode
    with dnnlib.util.open_url(opts['pkl']) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)
    G.eval()
    with dnnlib.util.open_url(opts['pkl']) as f:
        D = legacy.load_network_pkl(f)['D'].to(device)
    D.train()
    E = FastGANEncoder()
    E.to(device)
    E.train()
    # Initialize optimizers
    optim_d = AdamW(D.parameters())
    optim_e = AdamW(E.parameters())
    if opts.get('batch_gpu'):
        batch_desc = f"{opts['batch_gpu']} (total batch {opts['batch']})"
    else:
        batch_desc = str(opts['batch'])
    print(f"\n\nTraining for {opts['epochs']} epochs with batch size {batch_desc} on {num_classes} classes...\n\n")
    for epoch in range(opts['epochs']):
        running_loss_d = 0.
        running_loss_e = 0.
        optim_e.zero_grad()
        optim_d.zero_grad()
        # Iterate over dataset
        for i, (imgs, labels) in enumerate(dataloader, 0):
            with autocast():
                imgs, labels = imgs.cuda(), labels.cuda()
                # Encode the batch
                z_pred = E(imgs)
                # Pass the batch through the generator
                reconsts = G(z_pred, labels)
                # Pass the images through the discriminator
                fake_score = D(reconsts, labels)
                # Real images must require grad for the gradient penalty
                imgs.requires_grad_()
                real_score = D(imgs, labels)
                # Calculate the losses (reduced to scalars so backward() works)
                loss_e = encoder_loss(imgs, reconsts, fake_score)
                loss_d = (fake_score - real_score).mean() + discriminator_grad_reg(imgs, real_score)
            # Accumulate gradients every iteration, scaled by the number of accumulation steps;
            # retain_graph on the first backward because loss_d also backprops through fake_score
            (loss_e / accumulate_steps).backward(retain_graph=True)
            (loss_d / accumulate_steps).backward()
            running_loss_e += loss_e.item()
            running_loss_d += loss_d.item()
            # Only step (and reset) the optimizers every accumulate_steps iterations
            if (i + 1) % accumulate_steps == 0:
                optim_e.step()
                optim_d.step()
                optim_e.zero_grad()
                optim_d.zero_grad()
        print(f"Running loss of the discriminator at epoch {epoch + 1}: {running_loss_d}")
        print(f"Running loss of the encoder at epoch {epoch + 1}: {running_loss_e}")
I hope that helps, and I truly appreciate your help. The function interfaces with the click library, which is why it reads its options from the opts dictionary.
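For reference, the click entry point is wired up roughly like this; the option names and defaults below are only inferred from the opts keys used in search(), so treat them as approximations rather than the exact decorators in my script:

# Rough sketch of the click entry point (option names inferred from the opts keys above).
import click

@click.command()
@click.option('--data', required=True, help='Path to the ImageFolder-style dataset')
@click.option('--pkl', required=True, help='Network pickle containing G_ema and D')
@click.option('--batch', type=int, default=32, help='Total batch size')
@click.option('--batch-gpu', 'batch_gpu', type=int, default=None, help='Per-step batch size for gradient accumulation')
@click.option('--workers', type=int, default=4, help='DataLoader worker processes')
@click.option('--epochs', type=int, default=10, help='Number of training epochs')
def main(**kwargs):
    search(**kwargs)

if __name__ == '__main__':
    main()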