The code below behaves the same as TensorFlow's `depth_to_space` / `space_to_depth`, but operates on NCHW tensors:
from torch import nn
class DepthToSpace(nn.Module):
    """Move blocks of channels into the spatial dimensions of an NCHW tensor.

    The channel axis is read as ``(bs, bs, C // bs**2)``: block-row offset,
    block-column offset, then output channel. This is the same channel
    ordering TensorFlow's ``depth_to_space`` uses on the NHWC-transposed
    tensor.

    Args:
        block_size: Side length ``bs`` of the square block that each group
            of ``bs**2`` channels is expanded into.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def extra_repr(self):
        """Show the block size in ``repr(module)``."""
        return "block_size={}".format(self.bs)

    def forward(self, x):
        """Rearrange ``(N, C, H, W)`` into ``(N, C // bs**2, H * bs, W * bs)``.

        Args:
            x: 4-D tensor laid out as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, C // bs**2, H * bs, W * bs)``.

        Raises:
            RuntimeError: If ``C`` is not divisible by ``bs**2``.
        """
        N, C, H, W = x.size()
        bs = self.bs
        if C % (bs * bs) != 0:
            # Fail with an explicit message instead of the opaque shape
            # error that ``view`` would otherwise produce.
            raise RuntimeError(
                "DepthToSpace expects the channel dimension ({}) to be "
                "divisible by block_size ** 2 ({})".format(C, bs * bs)
            )
        out_channels = C // (bs * bs)
        x = x.view(N, bs, bs, out_channels, H, W)  # (N, bs, bs, C//bs^2, H, W)
        # Interleave each block offset with its spatial axis.
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        return x.view(N, out_channels, H * bs, W * bs)  # (N, C//bs^2, H*bs, W*bs)
class SpaceToDepth(nn.Module):
    """Fold spatial blocks of an NCHW tensor into the channel dimension.

    Each non-overlapping ``bs x bs`` spatial block is moved into the channel
    axis, which ends up ordered as ``(block-row offset, block-column offset,
    input channel)``. This is the same channel ordering TensorFlow's
    ``space_to_depth`` uses on the NHWC-transposed tensor.

    Args:
        block_size: Side length ``bs`` of the square spatial block that is
            folded into the channels.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def extra_repr(self):
        """Show the block size in ``repr(module)``."""
        return "block_size={}".format(self.bs)

    def forward(self, x):
        """Rearrange ``(N, C, H, W)`` into ``(N, C * bs**2, H // bs, W // bs)``.

        Args:
            x: 4-D tensor laid out as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, C * bs**2, H // bs, W // bs)``.

        Raises:
            RuntimeError: If ``H`` or ``W`` is not divisible by ``bs``.
        """
        N, C, H, W = x.size()
        bs = self.bs
        if H % bs != 0 or W % bs != 0:
            # Fail with an explicit message instead of the opaque shape
            # error that ``view`` would otherwise produce.
            raise RuntimeError(
                "SpaceToDepth expects height ({}) and width ({}) to be "
                "divisible by block_size ({})".format(H, W, bs)
            )
        out_h, out_w = H // bs, W // bs
        x = x.view(N, C, out_h, bs, out_w, bs)  # (N, C, H//bs, bs, W//bs, bs)
        # Bring both block offsets in front of the channel axis.
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        return x.view(N, C * (bs ** 2), out_h, out_w)  # (N, C*bs^2, H//bs, W//bs)
import tensorflow as tf
import torch
# Demo / sanity check: run the PyTorch modules and the TensorFlow ops on the
# same random input and compare the results element by element.
# NOTE(review): this uses the TensorFlow 1.x graph-mode API (tf.Session,
# tf.depth_to_space, tf.space_to_depth) -- confirm the installed TF version.

# pytorch
x1 = torch.rand(64, 256, 8, 8)
x2 = DepthToSpace(2)(x1)  # (64, 256, 8, 8) -> (64, 64, 16, 16)
x3 = SpaceToDepth(2)(x2)  # back to (64, 256, 8, 8)
print(x1.size())
print(x2.size())
print(x3.size())
# The round trip only permutes elements, so exact equality is expected.
print((x1 == x3).all())

# tensorflow
y1 = tf.transpose(x1.numpy(), [0, 2, 3, 1])  # NCHW -> NHWC
y2 = tf.depth_to_space(y1, 2)
y3 = tf.space_to_depth(y2, 2)
y1 = tf.transpose(y1, [0, 3, 1, 2])  # NHWC -> NCHW
y2 = tf.transpose(y2, [0, 3, 1, 2])
y3 = tf.transpose(y3, [0, 3, 1, 2])
# Use the session as a context manager so its resources are released once
# the three tensors have been evaluated.
with tf.Session() as sess:
    y1, y2, y3 = sess.run([y1, y2, y3])
print(y1.shape)
print(y2.shape)
print(y3.shape)
print((y1 == y3).all())

# check consistency
print((x1.numpy() == y1).all())
print((x2.numpy() == y2).all())
print((x3.numpy() == y3).all())