No-Overlap Sliding Windows for 3D Data, but overlap for Edge Cases

So I have a series of tensors of size 96x440x440, and I am thinking of turning them into chunks of size 96x96x96.
I would like to do a sliding window over the last 2 axes.

Using a simple tensor of 3x3 for example

1 , 2 , 3
4, 5, 6
7, 8, 9

I would like to get 2x2 chunks… oh… patches.
The result should be:

1,2
4,5

2,3
5,6

4,5
7,8

5,6
8,9

I am using a sliding window of size 2x2 here; if we padded instead, we would get a result like:
3,0
6,0

But for edge cases like this, I want the window to “overlap” back onto parts that are already used in other patches (3D chunks).

I can already implement this using torch indexing techniques with a loop and some if conditions,
but I wonder: is it possible to do it with PyTorch built-ins like unfold?

Thx in advance!

My description may be a little confusing; to put it in code, it would be:


def DDD_from_DHW(DHW):
    """Split a D x H x W tensor into D x D x D chunks.

    Full chunks tile the tensor with stride D along H and W.  When H or W
    is not a multiple of D, one extra chunk per direction is anchored flush
    against the far edge, overlapping data already covered by the previous
    chunk instead of padding with zeros.

    Assumes H >= D and W >= D.  Returns a list of (D, D, D) tensors in
    row-major (H, W) order.

    Note: the original version of this block had `height + 1and` (missing
    space), which is a SyntaxError; fixed here.
    """
    depth, height, width = DHW.shape
    full_h = height // depth   # number of full chunks along H
    full_w = width // depth    # number of full chunks along W

    # Start offsets along each axis; the final, clamped start produces the
    # overlapping edge chunk only when there is a remainder.
    h_starts = [i * depth for i in range(full_h)]
    if height % depth:
        h_starts.append(height - depth)
    w_starts = [j * depth for j in range(full_w)]
    if width % depth:
        w_starts.append(width - depth)

    return [DHW[:, h: h + depth, w: w + depth]
            for h in h_starts
            for w in w_starts]

My code for patching and recon looks like this:


import torch

def DDD_from_DHW(DHW):
    """Produce D x D x D chunks from a D x H x W tensor.

    The (H, W) plane is tiled with non-overlapping D x D windows; if H or
    W does not divide evenly by D, one extra chunk per row/column is taken
    flush against the far edge, "overlapping" back onto already-used data
    instead of padding.  Assumes H >= D and W >= D.

    Returns a list of (D, D, D) tensors in row-major (H, W) order.
    """
    depth = DHW.shape[0]
    height = DHW.shape[1]
    width = DHW.shape[2]
    full_h = height // depth
    full_w = width // depth
    rem_h = height % depth   # nonzero -> one extra (overlapping) row of chunks
    rem_w = width % depth    # nonzero -> one extra (overlapping) column of chunks

    chunks = []
    for hi in range(full_h + 1):
        if hi < full_h:
            h_slice = slice(hi * depth, (hi + 1) * depth)
        elif rem_h:
            # Edge case along H: anchor the last chunk to the bottom edge.
            h_slice = slice(-depth, None)
        else:
            continue  # H divides evenly: no extra chunk row.
        for wi in range(full_w + 1):
            if wi < full_w:
                w_slice = slice(wi * depth, (wi + 1) * depth)
            elif rem_w:
                # Edge case along W: anchor the last chunk to the right edge.
                w_slice = slice(-depth, None)
            else:
                continue  # W divides evenly: no extra chunk column.
            chunks.append(DHW[:, h_slice, w_slice])
    return chunks


def DHW_from_DDD(DDD, shape_of_DHW: tuple):
    """Reassemble a D x H x W tensor from chunks produced by DDD_from_DHW.

    DDD: list of (D, D, D) tensors in row-major (H, W) order, including the
        overlapping edge chunks when H or W does not divide evenly by D.
    shape_of_DHW: the (D, H, W) shape of the original tensor.

    Edge chunks are written last over the region they overlap; since they
    contain exactly the data that was read from there, reconstruction is
    exact.  The result keeps the chunks' dtype (the previous version always
    allocated a float32 buffer, silently casting non-float chunks).
    """
    depth, height, width = shape_of_DHW
    full_h = height // depth
    full_w = width // depth
    rem_h = height % depth
    rem_w = width % depth

    # One extra chunk row/column exists only when there is a remainder.
    # (This replaces four near-identical reshape branches plus an
    # unreachable `else: raise`.)
    n_rows = full_h + (1 if rem_h else 0)
    n_cols = full_w + (1 if rem_w else 0)
    grid = torch.reshape(torch.stack(DDD),
                         (n_rows, n_cols, depth, depth, depth))

    out = torch.zeros(shape_of_DHW, dtype=grid.dtype)
    for hi in range(n_rows):
        # Full rows start at hi*D; the final (remainder) row is anchored
        # to the bottom edge, overlapping the previous row.
        h_slice = (slice(hi * depth, (hi + 1) * depth) if hi < full_h
                   else slice(-depth, None))
        for wi in range(n_cols):
            w_slice = (slice(wi * depth, (wi + 1) * depth) if wi < full_w
                       else slice(-depth, None))
            out[:, h_slice, w_slice] = grid[hi, wi]
    return out


def _demo():
    """Round-trip a small tensor through chunking and reconstruction."""
    # 7 and 6 exercise both the remainder (H) and exact-divide (W) paths.
    original = torch.rand(3, 7, 6)  # renamed: `input` shadowed the builtin
    print(original)
    chunks = DDD_from_DHW(original)
    print(chunks)
    recon = DHW_from_DDD(chunks, original.shape)
    print(recon)
    print("Match Up Result:")
    # Exact copy-back of the same values, so every element compares True.
    print(original == recon)


if __name__ == "__main__":
    _demo()