Hi,
I’ve implemented a space_to_depth module in PyTorch using split, stack and permute operations.
Note that it expects input in BCHW format; to work with BHWC tensors instead, remove the first and last permute in “forward”.
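For BHWC tensors, the same rearrangement (the class below with its first and last permute dropped) can also be written without the Python-level split/stack, as a reshape, one permute and another reshape. This is a minimal sketch: the function name `space_to_depth_bhwc` is my own, and it assumes H and W are divisible by the block size.

```python
import torch


def space_to_depth_bhwc(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """(B, H, W, C) -> (B, H/bs, W/bs, bs*bs*C), depth ordered (row offset, col offset, channel)."""
    b, h, w, c = x.shape
    bs = block_size
    if h % bs or w % bs:
        raise ValueError("H and W must be divisible by block_size")
    # Split H and W into (block index, offset within block): [b, i, dh, j, dw, c]
    x = x.reshape(b, h // bs, bs, w // bs, bs, c)
    # Bring both offsets next to the channel axis: [b, i, j, dh, dw, c]
    x = x.permute(0, 1, 3, 2, 4, 5)
    # Merge (dh, dw, c) into a single depth axis (reshape copies, since x is no longer contiguous)
    return x.reshape(b, h // bs, w // bs, bs * bs * c)


x = torch.arange(16).reshape(1, 4, 4, 1)
y = space_to_depth_bhwc(x, 2)
print(y.shape)     # torch.Size([1, 2, 2, 4])
print(y[0, 0, 0])  # tensor([0, 1, 4, 5])
```

Each output pixel collects one 2x2 block of the input, read row by row.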
I’ve also implemented depth_to_space via this depth_to_space pytorch.
Both have been tested; if you’d like to see the testing code, I can upload it as well.
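```python
import torch
import torch.nn as nn


class SpaceToDepth(nn.Module):
    """(B, C, H, W) -> (B, C*bs*bs, H/bs, W/bs); H and W must be divisible by bs.
    Depth is ordered (row offset, column offset, channel), as in TensorFlow's space_to_depth."""

    def __init__(self, block_size):
        super(SpaceToDepth, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        # BCHW -> BHWC (drop this permute if the input is already BHWC)
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_height = s_height // self.block_size
        # Split along the width into strips of shape (B, H, block_size, C)
        t_1 = output.split(self.block_size, 2)
        # Flatten each strip to (B, H/bs, bs*bs*C)
        stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1]
        # Stack strips -> (B, W/bs, H/bs, D), then swap -> (B, H/bs, W/bs, D)
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        # BHWC -> BCHW (drop this permute to keep BHWC output)
        output = output.permute(0, 3, 1, 2)
        return output
```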
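For reference, here is a minimal sketch of a matching depth_to_space for BCHW tensors, written with reshape/permute rather than split/stack. It is not the linked depth_to_space code; the function names are my own, and it assumes the same depth ordering as the SpaceToDepth class (row offset, column offset, channel) and sizes divisible by the block size. A round trip recovers the input exactly, since both functions are pure permutations of elements.

```python
import torch
import torch.nn.functional as F


def space_to_depth(x: torch.Tensor, bs: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, bs*bs*C, H/bs, W/bs), depth ordered (row off, col off, channel)."""
    b, c, h, w = x.shape
    if h % bs or w % bs:
        raise ValueError("H and W must be divisible by bs")
    x = x.reshape(b, c, h // bs, bs, w // bs, bs)  # [b, c, i, dh, j, dw]
    x = x.permute(0, 3, 5, 1, 2, 4)                 # [b, dh, dw, c, i, j]
    return x.reshape(b, bs * bs * c, h // bs, w // bs)


def depth_to_space(x: torch.Tensor, bs: int) -> torch.Tensor:
    """Inverse of space_to_depth: (B, bs*bs*C, H, W) -> (B, C, H*bs, W*bs)."""
    b, d, h, w = x.shape
    if d % (bs * bs):
        raise ValueError("channel count must be divisible by bs*bs")
    c = d // (bs * bs)
    x = x.reshape(b, bs, bs, c, h, w)               # [b, dh, dw, c, i, j]
    x = x.permute(0, 3, 4, 1, 5, 2)                 # [b, c, i, dh, j, dw]
    return x.reshape(b, c, h * bs, w * bs)


x = torch.randn(2, 3, 8, 6)
assert torch.equal(depth_to_space(space_to_depth(x, 2), 2), x)

# With C == 1 this ordering coincides with PyTorch's built-in pixel_unshuffle.
g = torch.randn(2, 1, 8, 6)
assert torch.equal(space_to_depth(g, 2), F.pixel_unshuffle(g, 2))
```

Newer PyTorch versions (1.8+) provide nn.PixelShuffle / nn.PixelUnshuffle, but they order channels channel-major (channel, row offset, column offset), so for C > 1 their output is a channel permutation of the layout above.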