Unusually slow behavior when using a pixel shuffle layer

Hi,

I’m trying to implement SRGAN in PyTorch. The paper proposes a pixel shuffle layer to upscale images. However, when I use this layer the time per epoch increases drastically, a behavior I don’t observe when running the same code in TensorFlow. It seems the PyTorch code is not utilizing the GPU properly: the volatile GPU utilization oscillates between 10% and 90%, whereas it stays around 98-99% when I remove the layer, as it does with regular GANs. I’m therefore wondering whether the slowdown comes from my code or from PyTorch.

Here is my PyTorch implementation:

import torch

def pixel_shuffle_layer(x, r, n_split):
    def PS(x, r):
        # assumes tf ordering: (bs, H, W, r*r)
        bs, a, b, c = x.size()
        x = x.contiguous().view(bs, a, b, r, r)
        x = x.permute(0, 1, 2, 4, 3)
        # interleave the r*r sub-pixels along the spatial dimensions
        x = torch.chunk(x, a, dim=1)
        x = torch.cat([x_.squeeze() for x_ in x], dim=2)
        x = torch.chunk(x, b, dim=1)
        x = torch.cat([x_.squeeze() for x_ in x], dim=2)
        return x.view(bs, a*r, b*r, 1)

    # put in tf ordering: (bs, C, H, W) -> (bs, H, W, C)
    x = x.permute(0, 2, 3, 1)
    xc = torch.chunk(x, n_split, dim=3)
    x = torch.cat([PS(x_, r) for x_ in xc], dim=3)
    # put back in th ordering: (bs, H, W, C) -> (bs, C, H, W)
    x = x.permute(0, 3, 1, 2)
    return x
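
For reference, here is a minimal sketch (with hypothetical tensor sizes) of how the layer could be timed in isolation; torch.cuda.synchronize() is used so that asynchronous CUDA kernel launches don’t skew the measurement:

import time
import torch

# Hypothetical input: batch of 16 feature maps, 256 channels, upscale factor r = 2
x = torch.randn(16, 256, 24, 24).cuda()

torch.cuda.synchronize()          # make sure pending kernels are done before timing
start = time.time()
for _ in range(100):
    y = pixel_shuffle_layer(x, 2, 256 // 2**2)
torch.cuda.synchronize()          # wait for all launched kernels to finish
print('custom pixel shuffle: %.4f s for 100 calls' % (time.time() - start))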

Here is the TensorFlow implementation:

import tensorflow as tf

def pixel_shuffle_layer(x, r, n_split):
    def PS(x, r):
        # input is (bs, H, W, r*r)
        bs, a, b, c = x.get_shape().as_list()
        x = tf.reshape(x, (bs, a, b, r, r))
        x = tf.transpose(x, (0, 1, 2, 4, 3))
        x = tf.split(x, a, 1)
        x = tf.concat([tf.squeeze(x_) for x_ in x], 2)
        x = tf.split(x, b, 1)
        x = tf.concat([tf.squeeze(x_) for x_ in x], 2)
        return tf.reshape(x, (bs, a*r, b*r, 1))

    xc = tf.split(x, n_split, 3)
    return tf.concat([PS(x_, r) for x_ in xc], 3)

Finally, here is how I’m calling the layer in my forward pass:

def forward(self, input):
    val = self.deconv_one(input)
    val = nn.ReLU()(val)
    shortcut = val 

    for i in range(self.B):
        mid = val
        val = self.blocks[i](val)
        val = val + mid

    val = self.deconv_a(val)
    if self.bn: val = self.bn_a(val)
    val = val + shortcut

    val = self.deconv_b(val)
    bs, C, H, W = val.size()
    val = pixel_shuffle_layer(val, 2, C // 2**2)  # integer division: n_split must be an int
    val = nn.ReLU()(val)

    val = self.deconv_c(val)
    bs, C, H, W = val.size()
    val = pixel_shuffle_layer(val, 2, C // 2**2)  # integer division: n_split must be an int
    val = nn.ReLU()(val)

    val = self.deconv_d(val)
    if self.last_nl is not None:
        val = self.last_nl(val)

    return val

Thanks in advance,
Lucas

Just in case: there is already a PixelShuffle implementation in PyTorch, have you tried that?
http://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle
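
A minimal usage sketch (hypothetical shapes, upscale factor r = 2): nn.PixelShuffle expects an input with C*r^2 channels and returns a tensor of shape (bs, C, H*r, W*r).

import torch
import torch.nn as nn

ps = nn.PixelShuffle(2)             # upscale factor r = 2
x = torch.randn(16, 256, 24, 24)    # hypothetical (bs, C*r^2, H, W) input with C = 64
y = ps(x)
print(y.size())                     # torch.Size([16, 64, 48, 48])

In the forward pass above it should be able to replace the pixel_shuffle_layer(val, 2, C // 2**2) calls directly, e.g. val = nn.PixelShuffle(2)(val), with no n_split argument needed.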


Hi,

I just tried the PixelShuffle layer you linked. It works great!

Thanks a lot 🙂