Not sure how to translate parts of tf/keras model/training to pytorch?

I’m trying to convert this model and training code for pytorch (originally taken from HERE):

import os
import shutil

# example of pix2pix gan for satellite to map image-to-image translation
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot

# define the discriminator model
def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_src_image = Input(shape=image_shape)
	# target image input
	in_target_image = Input(shape=image_shape)
	# concatenate images channel-wise
	merged = Concatenate()([in_src_image, in_target_image])
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	patch_out = Activation('sigmoid')(d)
	# define model
	model = Model([in_src_image, in_target_image], patch_out)
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
	return model

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add downsampling layer
	g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# conditionally add batch normalization
	if batchnorm:
		g = BatchNormalization()(g, training=True)
	# leaky relu activation
	g = LeakyReLU(alpha=0.2)(g)
	return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add upsampling layer
	g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# add batch normalization
	g = BatchNormalization()(g, training=True)
	# conditionally add dropout
	if dropout:
		g = Dropout(0.5)(g, training=True)
	# merge with skip connection
	g = Concatenate()([g, skip_in])
	# relu activation
	g = Activation('relu')(g)
	return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# encoder model
	e1 = define_encoder_block(in_image, 64, batchnorm=False)
	e2 = define_encoder_block(e1, 128)
	e3 = define_encoder_block(e2, 256)
	e4 = define_encoder_block(e3, 512)
	e5 = define_encoder_block(e4, 512)
	e6 = define_encoder_block(e5, 512)
	e7 = define_encoder_block(e6, 512)
	# bottleneck, no batch norm and relu
	b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
	b = Activation('relu')(b)
	# decoder model
	d1 = decoder_block(b, e7, 512)
	d2 = decoder_block(d1, e6, 512)
	d3 = decoder_block(d2, e5, 512)
	d4 = decoder_block(d3, e4, 512, dropout=False)
	d5 = decoder_block(d4, e3, 256, dropout=False)
	d6 = decoder_block(d5, e2, 128, dropout=False)
	d7 = decoder_block(d6, e1, 64, dropout=False)
	# output
	g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
	# make weights in the discriminator not trainable
	for layer in d_model.layers:
		if not isinstance(layer, BatchNormalization):
			layer.trainable = False
	# define the source image
	in_src = Input(shape=image_shape)
	# connect the source image to the generator input
	gen_out = g_model(in_src)
	# connect the source input and generator output to the discriminator input
	dis_out = d_model([in_src, gen_out])
	# src image as input, generated image and classification output
	model = Model(in_src, [dis_out, gen_out])
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
	return model

# load and prepare training images
def load_real_samples(filename):
	# load compressed arrays
	data = load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# unpack dataset
	trainA, trainB = dataset
	# choose random instances
	ix = randint(0, trainA.shape[0], n_samples)
	# retrieve selected images
	X1, X2 = trainA[ix], trainB[ix]
	# generate 'real' class labels (1)
	y = ones((n_samples, patch_shape, patch_shape, 1))
	return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
	# generate fake instance
	X = g_model.predict(samples)
	# create 'fake' class labels (0)
	y = zeros((len(X), patch_shape, patch_shape, 1))
	return X, y

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
	# select a sample of input images
	[X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
	# generate a batch of fake samples
	X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
	# scale all pixels from [-1,1] to [0,1]
	X_realA = (X_realA + 1) / 2.0
	X_realB = (X_realB + 1) / 2.0
	X_fakeB = (X_fakeB + 1) / 2.0
	# plot real source images
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realA[i])
	# plot generated target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_fakeB[i])
	# plot real target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realB[i])
	# save plot to file
	filename1 = 'plot_%06d.png' % (step+1)
	pyplot.savefig(os.path.join(GRAPHS_DIR, filename1))
	pyplot.close()

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
	# determine the output square shape of the discriminator
	n_patch = d_model.output_shape[1]
	# unpack dataset
	trainA, trainB = dataset
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
		# update discriminator for real samples
		d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
		# update discriminator for generated samples
		d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
		# update the generator
		g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
		# summarize performance
		print('\r>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss), end='')
		# summarize model performance
		if (i+1) % (bat_per_epo * 1) == 0:
			summarize_performance(i, g_model, dataset)
		# if (i+1) % (bat_per_epo * 10) == 0:
			# save the generator model
			filename2 = 'model_%06d.h5' % (i+1)
			g_model.save(os.path.join(MODELS_DIR, filename2))
			print('>Saved: %s' % filename2)


if __name__ == '__main__':
	os.system('cls')

	MODELS_DIR = "models_per_epoch"
	if os.path.exists(MODELS_DIR):
		shutil.rmtree(MODELS_DIR)
	os.mkdir(MODELS_DIR)

	GRAPHS_DIR = "graphs"
	if os.path.exists(GRAPHS_DIR):
		shutil.rmtree(GRAPHS_DIR)
	os.mkdir(GRAPHS_DIR)

	# load image data
	dataset = load_real_samples('data/data_full.npz')
	print('Loaded', dataset[0].shape, dataset[1].shape)
	# dataset[0] = dataset[0][:10]
	# dataset[1] = dataset[1][:10]
	# define input shape based on the loaded dataset
	image_shape = dataset[0].shape[1:]
	# define the models
	d_model = define_discriminator(image_shape)
	g_model = define_generator(image_shape)
	# define the composite model
	gan_model = define_gan(g_model, d_model, image_shape)
	# train model
	train(d_model, g_model, gan_model, dataset)

I believe I have accurately translated the model so far (if anyone notices any in inequivalencies, please let me know!):

import os
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np


def conv(in_channels, out_channels, kernel_size=4, stride=2, padding="same", batch_norm=True):
	layers = []
	layers.append(nn.Conv2d(
		in_channels=in_channels, 
		out_channels=out_channels, 
		kernel_size=kernel_size, 
		stride=stride, 
		padding=padding, 
		bias=False
	))
	if batch_norm:
		layers.append(nn.BatchNorm2d(out_channels))
	return nn.Sequential(*layers)


def deconv(in_channels, out_channels, kernel_size=4, stride=2, padding="same", batch_norm=True):
	layers = []
	layers.append(nn.ConvTranspose2d(
		in_channels=in_channels, 
		out_channels=out_channels, 
		kernel_size=kernel_size, 
		stride=stride, 
		padding=padding, 
		bias=False
	))
	if batch_norm:
		layers.append(nn.BatchNorm2d(out_channels))
	return nn.Sequential(*layers)


class EncoderBlock(nn.Module):
	
	def __init__(self, in_channels, out_channels, batch_norm=True):
		super(EncoderBlock, self).__init__()
		self.conv1 = conv(in_channels=in_channels, out_channels=out_channels, batch_norm=batch_norm)

	def forward(self, x):
		out = F.leaky_relu(x, .2)
		return out


class DecoderBlock(nn.Module):
	
	def __init__(self, in_channels, out_channels, dropout=True):
		super(DecoderBlock, self).__init__()
		self.deconv1 = deconv(in_channels=in_channels, out_channels=out_channels, batch_norm=True)
		self.dropout = dropout

	def forward(self, x, prev_out):
		out = self.deconv1(x)
		if self.dropout:
			out = F.dropout(out, .5)
		out = torch.cat([out, prev_out])
		out = F.relu(out)
		return out


class Discriminator(nn.Module):
	
	def __init__(self):
		super(Discriminator, self).__init__()
		self.conv1 = conv(6, 64)
		self.conv2 = conv(64, 128)
		self.conv3 = conv(128, 256)
		self.conv4 = conv(256, 512)
		self.conv5 = conv(512, 512)
		self.conv6 = conv(512, 1)
		self.leaky_relu = nn.LeakyReLU(.2)

	def forward(self, x, y):
		out = torch.cat([x, y])
		out = self.leaky_relu(self.conv1(x))
		out = self.leaky_relu(self.conv2(out))
		out = self.leaky_relu(self.conv3(out))
		out = self.leaky_relu(self.conv4(out))
		out = self.leaky_relu(self.conv5(out))
		out = F.sigmoid(self.conv6(out))
		return out


class Generator(nn.Module):
	
	def __init__(self):
		super(Generator, self).__init__()

		self.e1 = EncoderBlock(3, 64, batch_norm=False)
		self.e2 = EncoderBlock(64, 128),
		self.e3 = EncoderBlock(128, 256),
		self.e4 = EncoderBlock(256, 512),
		self.e5 = EncoderBlock(512, 512),
		self.e6 = EncoderBlock(512, 512),
		self.e7 = EncoderBlock(512, 512),

		self.b = conv(512, 512, batch_norm=False)

		self.d1 = DecoderBlock(512, 512),
		self.d2 = DecoderBlock(512, 512),
		self.d3 = DecoderBlock(512, 512),
		self.d4 = DecoderBlock(512, 512, dropout=False),
		self.d5 = DecoderBlock(512, 256, dropout=False),
		self.d6 = DecoderBlock(256, 128, dropout=False),
		self.d7 = DecoderBlock(128, 64, dropout=False),

		self.deconv1 = deconv(64, 3)

	def forward(self, x):

		e1 = self.e1(x)
		e2 = self.e1(e1)
		e3 = self.e1(e2)
		e4 = self.e1(e3)
		e5 = self.e1(e4)
		e6 = self.e1(e5)
		e7 = self.e1(e6)

		b = F.relu(self.b(e7))

		d1 = self.d1(b, e7)
		d2 = self.d2(d1, e6)
		d3 = self.d3(d2, e5)
		d4 = self.d4(d3, e4)
		d5 = self.d5(d4, e3)
		d6 = self.d6(d5, e2)
		d7 = self.d7(d6, e1)

		out = F.tanh(self.deconv1(d7))
		return out


class GAN(nn.Module):

	def __init__(self, generator, discriminator):
		super(GAN, self).__init__()
		
		for layer in discriminator.children():
			if not isinstance(layer, nn.BatchNorm2d):
				layer.eval()
				layer.track_running_stats = False

		self.generator = generator
		self.discriminator = discriminator

	def forward(self, x):
		g_out = self.generator(x)
		
		d_in = torch.cat([x, g_out])
		d_out = self.discriminator(d_in)

		out = torch.cat([d_out, g_out])
		return out

But for the training loop, how do I translate this?:

def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
	n_patch = d_model.output_shape[1]
	trainA, trainB = dataset
	bat_per_epo = int(len(trainA) / n_batch)
	n_steps = bat_per_epo * n_epochs
	for i in range(n_steps):
		[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
		X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
		d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
		d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
		g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])

Specifically:

  • do I need to define individual train and eval loops for the generator, discriminator, and gan?
  • in the tf/keras version, only the GAN has a specified optimizer and loss function, do I only need to specify the optimizer and loss for the GAN in pytorch as well?