Hi,
I’m trying to implement SRGAN in pytorch. The paper proposes a pixel shuffle layer in order to upscale images. However, when using such layer, the time per epoch increases drastically, a behavior not observed when running the same code in Tensorflow. It seems that the -pytorch- code is not utilizing the GPU correctly, as the volatile gpu memory oscillates between 10-90% : when removing the layer it stays around 98-99%, as it does with regular gans. Therefore, I’m wondering if the slow behavior is because of my code, or due to pytorch.
Here is my pytorch implementation :
def pixel_shuffle_layer(x, r, n_split):
def PS(x, r):
# assumes tf ordering
bs, a, b, c = x.size()
x = x.contiguous().view(bs, a, b, r, r)
x = x.permute(0,1,2,4,3)
x = torch.chunk(x, a, dim=1)
x = torch.cat([x_.squeeze() for x_ in x], dim=2)
x = torch.chunk(x, b, dim=1)
x = torch.cat([x_.squeeze() for x_ in x], dim=2)
return x.view(bs, a*r, b*r, 1)
# put in tf ordering
x = x.permute(0,2,3,1)
xc = torch.chunk(x, n_split, dim=3)
x = torch.cat([PS(x_, r) for x_ in xc], dim=3)
# put back in th ordering
x = x.permute(0,3,1,2)
return x
Here is the tensorflow implementation
def pixel_shuffle_layer(x, r, n_split):
def PS(x, r):
bs, a, b, c = x.get_shape().as_list()
x = tf.reshape(x, (bs, a, b, r, r))
x = tf.transpose(x, (0, 1, 2, 4, 3))
x = tf.split(x, a, 1)
x = tf.concat([tf.squeeze(x_) for x_ in x], 2)
x = tf.split(x, b, 1)
x = tf.concat([tf.squeeze(x_) for x_ in x], 2)
return tf.reshape(x, (bs, a*r, b*r, 1))
xc = tf.split(x, n_split, 3)
return tf.concat([PS(x_, r) for x_ in xc], 3)
Finally, here is how I’m calling the layer in my forward pass :
def forward(self, input):
val = self.deconv_one(input)
val = nn.ReLU()(val)
shortcut = val
for i in range(self.B):
mid = val
val = self.blocks[i](val)
val = val + mid
val = self.deconv_a(val)
if self.bn : val = self.bn_a(val)
val = val + shortcut
val = self.deconv_b(val)
bs, C, H, W = val.size()
val = pixel_shuffle_layer(val, 2, C / 2**2)
val = nn.ReLU()(val)
val = self.deconv_c(val)
bs, C, H, W = val.size()
val = pixel_shuffle_layer(val, 2, C / 2**2)
val = nn.ReLU()(val)
val = self.deconv_d(val)
if self.last_nl is not None :
val = self.last_nl(val)
return val
Thanks in advance,
Lucas