BAGAN implementation from TensorFlow to PyTorch

Hello. I’m currently trying to rewrite IBM’s BAGAN implementation so I can use it in a personal project and see whether it helps with the imbalanced-dataset issue I’m facing right now.
However, I’m having some trouble replicating the behavior of the generator, shown below:

def build_generator(self, latent_size, init_resolution=8):
        resolution = self.resolution
        channels = self.channels

        # we will map a pair of (z, L), where z is a latent vector and L is a
        # label drawn from P_c, to image space (..., 3, resolution, resolution)
        cnn = Sequential()

        cnn.add(Dense(1024, input_dim=latent_size, activation='relu', use_bias=False))
        cnn.add(Dense(128 * init_resolution * init_resolution, activation='relu', use_bias=False))
        cnn.add(Reshape((128, init_resolution, init_resolution)))
        crt_res = init_resolution

        # upsample
        while crt_res != resolution:
            cnn.add(UpSampling2D(size=(2, 2)))
            if crt_res < resolution/2:
                cnn.add(Conv2D(
                    256, (5, 5), padding='same',
                    activation='relu', kernel_initializer='glorot_normal', use_bias=False)
                )

            else:
                cnn.add(Conv2D(128, (5, 5), padding='same',
                                      activation='relu', kernel_initializer='glorot_normal', use_bias=False))

            crt_res = crt_res * 2
            assert crt_res <= resolution,\
                "Error: final resolution [{}] must equal i*2^n. Initial resolution i is [{}]. n must be a natural number.".format(resolution, init_resolution)

        cnn.add(Conv2D(channels, (2, 2), padding='same',
                              activation='tanh', kernel_initializer='glorot_normal', use_bias=False))

It has this upsampling part where modules are added depending on a condition. However, in PyTorch we need to specify the number of input channels for each layer. How could I keep track of it, and make sure I’m not doing anything wrong in the last Conv2d layer? If someone is willing to share a “translation” of this implementation from TensorFlow to PyTorch, I would be grateful.

Thanks!

Just to give some context, this is what I came up with in PyTorch:

class BalancingGAN:
    def build_generator(self, latent_size, init_resolution=8, resolution=64, channels=3):
        # resolution = self.resolution
        # channels = self.channels

        cnn_torch = nn.Sequential()

        cnn_torch.add_module("Linear1", nn.Linear(latent_size, 1024, bias=False))
        cnn_torch.add_module("ReLU1", nn.ReLU(inplace=True))
        cnn_torch.add_module("Linear2", nn.Linear(1024, 128 * init_resolution * init_resolution, bias=False))
        cnn_torch.add_module("ReLU2", nn.ReLU(inplace=True))
        # equivalent of Keras' Reshape((128, init_resolution, init_resolution))
        cnn_torch.add_module("Reshape", nn.Unflatten(1, (128, init_resolution, init_resolution)))

        crt_res = init_resolution

        # upsample (module names include crt_res so they stay unique across iterations;
        # add_module with a repeated name would overwrite the previous module)
        while crt_res != resolution:
            cnn_torch.add_module(f"Upsample{crt_res}", nn.UpsamplingNearest2d(scale_factor=2))
            if crt_res < resolution / 2:
                # padding=2 mimics Keras' padding='same' for a 5x5 kernel
                cnn_torch.add_module(f"UpsampleConv2d{crt_res}", nn.Conv2d(input_size, 256, (5, 5), padding=2, bias=False))
            else:
                cnn_torch.add_module(f"UpsampleConv2d{crt_res}", nn.Conv2d(input_size, 128, (5, 5), padding=2, bias=False))
            cnn_torch.add_module(f"UpsampleReLU{crt_res}", nn.ReLU(inplace=True))

            crt_res = crt_res * 2
            assert crt_res <= resolution, \
                "Error: final resolution [{}] must equal i*2^n. Initial resolution i is [{}]. n must be a natural number.".format(resolution, init_resolution)

        cnn_torch.add_module("Conv2dFinal", nn.Conv2d(input_size, channels, (2, 2), padding='same', bias=False))
        cnn_torch.add_module("Tanh", nn.Tanh())

The “input_size” in the Conv2d layers in the upsample part is just a placeholder I used while trying to figure out what to actually pass there.
If something is wrong or should be done another way, please feel free to say so.
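To make my question concrete, this is roughly the channel-tracking approach I had in mind to replace the `input_size` placeholder: keep the current channel count in a local variable (`in_channels` below is my own name) and update it after each conv. I’m not sure this is idiomatic, so treat it as a sketch:

```python
import torch
import torch.nn as nn

def build_generator_sketch(latent_size, init_resolution=8, resolution=64, channels=3):
    net = nn.Sequential(
        nn.Linear(latent_size, 1024, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 128 * init_resolution * init_resolution, bias=False),
        nn.ReLU(inplace=True),
        # equivalent of Keras' Reshape((128, init_resolution, init_resolution))
        nn.Unflatten(1, (128, init_resolution, init_resolution)),
    )
    in_channels = 128  # channel count right after the reshape
    crt_res = init_resolution
    idx = 0
    while crt_res != resolution:
        net.add_module(f"upsample{idx}", nn.UpsamplingNearest2d(scale_factor=2))
        out_channels = 256 if crt_res < resolution / 2 else 128
        # padding=2 keeps the spatial size for a 5x5 kernel (Keras 'same')
        net.add_module(f"conv{idx}", nn.Conv2d(in_channels, out_channels, 5, padding=2, bias=False))
        net.add_module(f"relu{idx}", nn.ReLU(inplace=True))
        in_channels = out_channels  # the next layer sees this many channels
        crt_res *= 2
        idx += 1
    # padding='same' reproduces Keras' asymmetric padding for the 2x2 kernel
    net.add_module("conv_out", nn.Conv2d(in_channels, channels, 2, padding='same', bias=False))
    net.add_module("tanh", nn.Tanh())
    return net
```

A quick shape check with `build_generator_sketch(100)(torch.randn(4, 100))` should give a `(4, 3, 64, 64)` tensor, which matches what I expect from the Keras version.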

About the ‘glorot_normal’ initializer: I should do something like this, right? And then apply it to the model.

def weight_init_test(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
        if m.bias is not None:  # the conv layers use bias=False, so bias may be None
            nn.init.zeros_(m.bias)
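For completeness, my understanding is that it would be applied with `Module.apply`, which walks every submodule and calls the function on each one. A minimal check (the toy model here is just for illustration):

```python
import torch
import torch.nn as nn

def weight_init_test(m):
    # Rough PyTorch analogue of Keras' glorot_normal (Xavier normal init).
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
        if m.bias is not None:  # guard: the generator convs use bias=False
            nn.init.zeros_(m.bias)

# apply() recurses into every submodule, including those inside Sequential
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.ReLU())
model.apply(weight_init_test)
```

One thing I’m unsure about: Keras’ `glorot_normal` has no gain factor, so passing `calculate_gain('relu')` may not match it exactly; the default gain of 1.0 might be closer to the original behavior.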