# My code for patching and reconstruction looks like this:
import torch
def DDD_from_DHW(DHW):
    """Split a depth x height x width tensor into depth x depth x depth cubes.

    The size of the first axis (``depth``) is used as the cube edge. Cubes
    tile the height and width axes in row-major order: all cubes of the top
    row from left to right, then the next row, and so on. When height or
    width is not a multiple of depth, one extra row/column of cubes is taken
    flush against the far edge. That extra cube overlaps the preceding
    row/column instead of reading past the end of the tensor.

    Args:
        DHW: tensor indexed as ``DHW[depth, height, width]``.

    Returns:
        list of slices ``DHW[:, top:top + depth, left:left + depth]`` (views
        of ``DHW``, not copies), in row-major order.

    Raises:
        ValueError: if depth is 0, or if height or width is smaller than
            depth, since then no depth x depth x depth cube fits.
    """
    depth, height, width = DHW.shape[0], DHW.shape[1], DHW.shape[2]
    if depth < 1 or height < depth or width < depth:
        raise ValueError(
            "expected 0 < depth <= height and depth <= width, "
            f"got shape {tuple(DHW.shape)}")

    # Offsets of the full cubes along each axis, plus one cube aligned to the
    # far edge when the axis length leaves a remainder (width edge for
    # columns, height edge for rows).
    row_starts = [i * depth for i in range(height // depth)]
    if height % depth:
        row_starts.append(height - depth)
    col_starts = [j * depth for j in range(width // depth)]
    if width % depth:
        col_starts.append(width - depth)

    return [DHW[:, top:top + depth, left:left + depth]
            for top in row_starts
            for left in col_starts]
def DHW_from_DDD(DDD, shape_of_DHW: tuple):
    """Reassemble a depth x height x width tensor from depth-sized cubes.

    This is the inverse of ``DDD_from_DHW``. ``DDD`` must list the cubes in
    row-major order (top row left to right, then the next row, ...). When
    height or width is not a multiple of depth, the final row/column must be
    the cube aligned flush with the far edge. Cubes are written in list
    order, so where an edge cube overlaps its neighbour the later cube's
    values win.

    Args:
        DDD: sequence of tensors, each of shape (depth, depth, depth).
        shape_of_DHW: (depth, height, width) of the tensor to rebuild.

    Returns:
        tensor of shape ``shape_of_DHW`` with the same dtype and device as
        the cubes.

    Raises:
        ValueError: if depth is 0, or if height or width is smaller than
            depth.
        RuntimeError: from ``torch.stack``/``torch.reshape`` if the cubes'
            shapes or count do not match the tiling of ``shape_of_DHW``.
    """
    depth, height, width = shape_of_DHW[0], shape_of_DHW[1], shape_of_DHW[2]
    if depth < 1 or height < depth or width < depth:
        raise ValueError(
            "expected 0 < depth <= height and depth <= width, "
            f"got shape {tuple(shape_of_DHW)}")

    # Same tiling as the split: full cubes, plus one edge-aligned cube per
    # axis when the length is not a multiple of depth.
    row_starts = [i * depth for i in range(height // depth)]
    if height % depth:
        row_starts.append(height - depth)
    col_starts = [j * depth for j in range(width // depth)]
    if width % depth:
        col_starts.append(width - depth)

    grid = torch.reshape(
        torch.stack(DDD),
        (len(row_starts), len(col_starts), depth, depth, depth))

    # Match the cubes' dtype/device. A bare torch.zeros would use the global
    # default dtype/device, so float64 or integer data would be silently
    # cast (and CUDA chunks would be copied onto a CPU tensor).
    recon = torch.zeros(shape_of_DHW, dtype=grid.dtype, device=grid.device)
    for i, top in enumerate(row_starts):
        for j, left in enumerate(col_starts):
            recon[:, top:top + depth, left:left + depth] = grid[i, j]
    return recon
# Demo: split a random (3, 7, 6) tensor into 3x3x3 cubes and rebuild it.
# Named `sample` rather than `input` to avoid shadowing the builtin.
sample = torch.rand(3, 7, 6)
print(sample)
output = DDD_from_DHW(sample)
print(output)
recon = DHW_from_DDD(output, sample.shape)
print(recon)
print("Match Up Result:")
# torch.equal yields a single True/False instead of an element-wise mask.
print(torch.equal(sample, recon))