Is there any layer like tensorflow's space_to_depth() function?

Is there any layer or module in PyTorch that works like TensorFlow's space_to_depth() function? I have not found any concrete implementation of this operation in either PyTorch or TensorFlow! Could you please give me some hints about the implementation?


It's hard to understand this function (even after reading the docs), but maybe nn.PixelShuffle is the equivalent.
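For orientation, nn.PixelShuffle goes in the depth_to_space direction (channels to space), so on its own it would be the inverse of space_to_depth. A quick shape check:

import torch
import torch.nn as nn

# PixelShuffle: (N, C*r^2, H, W) -> (N, C, H*r, W*r), i.e. depth_to_space
x = torch.randn(1, 8, 13, 13)
y = nn.PixelShuffle(2)(x)
print(y.shape)  # torch.Size([1, 2, 26, 26])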


But here someone said that it is just a reshape:

class scale_to_depth(nn.Module):
    def __init__(self, block_size=1):
        super(scale_to_depth, self).__init__()
        self.block_size = block_size        
    def forward(self, input):
        inputSize = input.size()
        batch = inputSize[0]
        channel = inputSize[1]
        height = inputSize[2]
        width = inputSize[3]
        return input.view(-1, channel * self.block_size * self.block_size,
                      int(height/self.block_size), int(width/self.block_size))

Here is my implementation!

Another thing I should mention is that my goal is to create a layer that performs the operation below:

adding a passthrough layer that brings features from an earlier layer at 26 × 26 resolution to a layer with 13 × 13 resolution

Actually, I would like to reuse an earlier feature map.
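For context, a minimal sketch of that passthrough wiring; this assumes a recent PyTorch where nn.PixelUnshuffle is available (its channel ordering differs from tf.space_to_depth, which shouldn't matter when the result is only concatenated and fed to a convolution), and the feature-map names and sizes below are made up for illustration:

import torch
import torch.nn as nn

# Hypothetical feature maps from a detector backbone (sizes chosen for illustration)
early = torch.randn(1, 64, 26, 26)    # earlier, higher-resolution features
late = torch.randn(1, 1024, 13, 13)   # deeper, lower-resolution features

reorg = nn.PixelUnshuffle(2)          # space-to-depth style rearrangement: 26x26 -> 13x13
passthrough = torch.cat([reorg(early), late], dim=1)
print(passthrough.shape)              # torch.Size([1, 1280, 13, 13])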

@chenyuntc. Sorry for my question. Have you got any idea about this operation?

It seems that tf.space_to_depth can't transform 26x26 to 13x13, but your implementation seems well suited to your goal.

What if I'd like to use a different stride than the window size?

Basically, I think the operation with a k×k window should turn an [H, W, C] input into a [k, k, x] output, where x denotes how many times the window fits onto the input (just like when you reshape the input to a [k*k, x] matrix so that the convolution can be evaluated as an inner product, except that here only the reshape would be needed).

Exactly, that would be the same as https://www.tensorflow.org/api_docs/python/tf/extract_image_patches
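On the PyTorch side, torch.nn.functional.unfold extracts k×k patches with an arbitrary stride (flattened along the channel dimension), which is roughly what tf.extract_image_patches does. A small sketch, with shapes chosen just for illustration:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# 2x2 windows with stride 1 (overlapping), rather than stride == window size
patches = F.unfold(x, kernel_size=2, stride=1)   # (N, C*k*k, L) = (1, 12, 49)
print(patches.shape)
# L is the number of window positions; reshape back to a spatial grid if needed
print(patches.view(1, 12, 7, 7).shape)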

The view-based implementation seems wrong:

In [184]: t = tf.constant(np.arange(32.).reshape(1,4,4,2))

In [185]: t.eval()
Out[185]:
array([[[[  0.,   1.],
         [  2.,   3.],
         [  4.,   5.],
         [  6.,   7.]],

        [[  8.,   9.],
         [ 10.,  11.],
         [ 12.,  13.],
         [ 14.,  15.]],

        [[ 16.,  17.],
         [ 18.,  19.],
         [ 20.,  21.],
         [ 22.,  23.]],

        [[ 24.,  25.],
         [ 26.,  27.],
         [ 28.,  29.],
         [ 30.,  31.]]]])

In [186]: tf.space_to_depth(t,2).eval()
Out[186]:
array([[[[  0.,   1.,   2.,   3.,   8.,   9.,  10.,  11.],
         [  4.,   5.,   6.,   7.,  12.,  13.,  14.,  15.]],

        [[ 16.,  17.,  18.,  19.,  24.,  25.,  26.,  27.],
         [ 20.,  21.,  22.,  23.,  28.,  29.,  30.,  31.]]]])

In [187]: t = torch.Tensor(np.arange(32.).reshape(1,4,4,2))

In [188]: t
Out[188]:

(0 ,0 ,.,.) =
   0   1
   2   3
   4   5
   6   7

(0 ,1 ,.,.) =
   8   9
  10  11
  12  13
  14  15

(0 ,2 ,.,.) =
  16  17
  18  19
  20  21
  22  23

(0 ,3 ,.,.) =
  24  25
  26  27
  28  29
  30  31
[torch.FloatTensor of size 1x4x4x2]

In [189]: t.view(1,2,2,8)
Out[189]:

(0 ,0 ,.,.) =
   0   1   2   3   4   5   6   7
   8   9  10  11  12  13  14  15

(0 ,1 ,.,.) =
  16  17  18  19  20  21  22  23
  24  25  26  27  28  29  30  31
[torch.FloatTensor of size 1x2x2x8]

You need to permute the data before calling view, similar to how PixelShuffle is implemented (typing from my phone, otherwise I'd send you the link).

Yes, I realized right after posting but didn’t have internet…

It seems that there is no permutation that can make it work.
I tried the following snippet, which generates all permutations and checks whether any of them equals tf.space_to_depth(t, 2).eval():

import itertools

for p in itertools.permutations(range(4)):
    print(p, t.permute(*p).contiguous().view(1, 2, 2, 8))
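For what it's worth, it does work once H and W are each split into a block dimension first, so that six dimensions are permuted rather than four. A minimal sketch on the same NHWC tensor as above:

import numpy as np
import torch

t = torch.tensor(np.arange(32.).reshape(1, 4, 4, 2))   # NHWC, same tensor as above
# Split H -> (H//2, 2) and W -> (W//2, 2), then move the block offsets next to the channels
out = t.view(1, 2, 2, 2, 2, 2).permute(0, 1, 3, 2, 4, 5).contiguous().view(1, 2, 2, 8)
print(out)   # matches tf.space_to_depth(t, 2).eval()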

Hi,

I've implemented a class for space_to_depth in PyTorch using split, stack and permute operations.
Note that it expects input in BCHW format; you can remove the first and last permute in forward to work with BHWC format instead.
I've also implemented the reverse, depth_to_space, in the same way.
Both were tested; if you'd like to see the testing code, I can upload it as well.

class SpaceToDepth(nn.Module):
    def __init__(self, block_size):
        super(SpaceToDepth, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size*block_size

    def forward(self, input):
        # BCHW -> BHWC
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_width = int(s_width / self.block_size)
        d_height = int(s_height / self.block_size)
        # Split the width into block_size-wide slices
        t_1 = output.split(self.block_size, 2)
        # Fold each slice's spatial blocks into the depth dimension
        stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        # BHWC -> BCHW
        output = output.permute(0, 3, 1, 2)
        return output
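A quick usage check of the class above (shapes only, assuming the definition as posted):

import torch

x = torch.arange(32.).view(1, 2, 4, 4)   # BCHW
y = SpaceToDepth(2)(x)
print(y.shape)   # torch.Size([1, 8, 2, 2])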

Update to the code for PyTorch 0.4:

  • contiguous().view() --> reshape()

Updated link: https://gist.github.com/jalola/f41278bb27447bed9cd3fb48ec142aec


The code below works the same as TensorFlow:

from torch import nn

class DepthToSpace(nn.Module):

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W)  # (N, bs, bs, C//bs^2, H, W)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
        return x


class SpaceToDepth(nn.Module):

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
        return x


import tensorflow as tf
import torch

# pytorch
x1 = torch.rand(64, 256, 8, 8)
x2 = DepthToSpace(2)(x1)
x3 = SpaceToDepth(2)(x2)
print(x1.size())
print(x2.size())
print(x3.size())
print((x1 == x3).all())

# tensorflow
y1 = tf.transpose(x1.numpy(), [0, 2, 3, 1])  # NCHW -> NHWC
y2 = tf.depth_to_space(y1, 2)
y3 = tf.space_to_depth(y2, 2)

y1 = tf.transpose(y1, [0, 3, 1, 2])  # NHWC -> NCHW
y2 = tf.transpose(y2, [0, 3, 1, 2])
y3 = tf.transpose(y3, [0, 3, 1, 2])

y1, y2, y3 = tf.Session().run([y1, y2, y3])
print(y1.shape)
print(y2.shape)
print(y3.shape)
print((y1 == y3).all())

# check consistency
print((x1.numpy() == y1).all())
print((x2.numpy() == y2).all())
print((x3.numpy() == y3).all())

Thanks for sharing ^^ it works for me.