I would like to reshape my data, which represents a single image, so that the height and width (dims 2 and 3) shrink and the batch size grows, effectively turning small crops of the image into a batch.
I want one specific dimension, the channels (dim 1), to keep the same data.
I initially wrote a pair of functions to do something like this, but I fear they are not backprop compatible:
import math

import torch


def img_to_batch(imgtensor, patch_size: int):
    # Split a (1, ch, height, width) image into non-overlapping patches, row by row.
    _, ch, height, width = imgtensor.shape
    assert height % patch_size == 0 and width % patch_size == 0, (
        'img_to_batch: dims must be divisible by patch_size. {}%{}!=0'.format(imgtensor.shape, patch_size))
    bs = math.ceil(height / patch_size) * math.ceil(width / patch_size)
    btensor = torch.zeros([bs, ch, patch_size, patch_size], device=imgtensor.device, dtype=imgtensor.dtype)
    xstart = ystart = 0
    for i in range(bs):
        # Index the single image explicitly so the slice has shape (ch, patch_size, patch_size).
        btensor[i] = imgtensor[0, :, ystart:ystart+patch_size, xstart:xstart+patch_size]
        xstart += patch_size
        if xstart + patch_size > width:
            # End of a row of patches: go back to the left edge, one row down.
            xstart = 0
            ystart += patch_size
    return btensor


def batch_to_img(btensor, height: int, width: int, ch=3):
    # Inverse of img_to_batch: paste the patches back into a (1, ch, height, width) image.
    imgtensor = torch.zeros([1, ch, height, width], device=btensor.device, dtype=btensor.dtype)
    patch_size = btensor.shape[-1]
    xstart = ystart = 0
    for i in range(btensor.size(0)):
        imgtensor[0, :, ystart:ystart+patch_size, xstart:xstart+patch_size] = btensor[i]
        xstart += patch_size
        if xstart + patch_size > width:
            xstart = 0
            ystart += patch_size
    return imgtensor
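One way to test the backprop concern directly is to repeat the same slice-assignment pattern on a toy tensor and call backward. The sketch below is only a minimal check under that assumption: `src` and `dst` are made-up names, and the 2x2 patch grid is hard-coded for a 4x4 input. If autograd tracks the copies, `dst` should end up requiring grad and every element of `src` should receive a gradient.

```python
import torch

# Hypothetical toy tensors, not the image from the post.
src = torch.rand(1, 3, 4, 4, requires_grad=True)
dst = torch.zeros(4, 3, 2, 2)

# Same pattern as img_to_batch: copy each 2x2 crop into one batch entry.
k = 0
for y in (0, 2):
    for x in (0, 2):
        dst[k] = src[0, :, y:y + 2, x:x + 2]
        k += 1

# Sum the batch and backpropagate to see whether gradients reach src.
dst.sum().backward()
print(dst.requires_grad)
print(src.grad.shape)
```

Each input element is copied exactly once here, so a gradient of all ones in `src.grad` would indicate that the slice assignments are part of the graph.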
The torch library has simple view / reshape functions, but when I use them the channels do not keep their respective data. For example (I would like the second output below to contain the same elements as the first, even if out of order):
>>> imgtens = torch.rand(1,3,4,4)
>>> imgtens[:,0,:,:]
tensor([[[0.6830, 0.2091, 0.8786, 0.6002],
         [0.0325, 0.7217, 0.1479, 0.3478],
         [0.0880, 0.8705, 0.0929, 0.7978],
         [0.7604, 0.2658, 0.3518, 0.1969]]])
>>> reshaped = imgtens.view([4,3,2,2])
>>> reshaped[:,0,:,:]
tensor([[[0.6830, 0.2091],
         [0.8786, 0.6002]],

        [[0.7604, 0.2658],
         [0.3518, 0.1969]],

        [[0.3787, 0.0042],
         [0.3481, 0.2722]],

        [[0.4175, 0.8700],
         [0.1930, 0.7646]]])
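The mismatch above comes from the fact that view only reinterprets the flat memory order, so consecutive groups of four values become "patches" regardless of which channel or row they came from. A possible loop-free alternative, sketched here under the assumption of a single (1, C, H, W) image with H and W divisible by the patch size, is to split H and W into (block, offset) pairs and permute the block axes to the front. The helper names `img_to_batch_view` and `batch_to_img_view` are hypothetical.

```python
import torch


def img_to_batch_view(img: torch.Tensor, p: int) -> torch.Tensor:
    # Hypothetical helper: (1, C, H, W) -> (H//p * W//p, C, p, p), patches in row-major order.
    _, c, h, w = img.shape
    x = img.reshape(c, h // p, p, w // p, p)   # split H and W into (block, offset)
    x = x.permute(1, 3, 0, 2, 4)               # (block_y, block_x, C, p, p)
    return x.reshape(-1, c, p, p)              # merge the two block axes into the batch


def batch_to_img_view(batch: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # Hypothetical inverse: (H//p * W//p, C, p, p) -> (1, C, H, W).
    _, c, p, _ = batch.shape
    x = batch.reshape(h // p, w // p, c, p, p)
    x = x.permute(2, 0, 3, 1, 4)               # (C, block_y, p, block_x, p)
    return x.reshape(1, c, h, w)


# Small deterministic image so individual patches are easy to compare by eye.
img = torch.arange(1 * 3 * 4 * 6, dtype=torch.float32).reshape(1, 3, 4, 6)
patches = img_to_batch_view(img, 2)
restored = batch_to_img_view(patches, 4, 6)
print(patches.shape, restored.shape)
```

Since reshape and permute are ordinary differentiable operations, this variant should not raise any backprop concerns, and each patch keeps all channels of its crop.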
I read that the fold / unfold functions may be able to help by creating a sliding window.
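For reference, a minimal sketch of how that could look with `torch.nn.functional.unfold` and `torch.nn.functional.fold`, assuming a floating-point (1, C, H, W) input and non-overlapping windows (stride equal to the patch size). The helper names are hypothetical. `unfold` returns one column per window with the channel values flattened channel-first, so a transpose and a reshape turn the columns into a batch of patches.

```python
import torch
import torch.nn.functional as F


def img_to_batch_unfold(img: torch.Tensor, p: int) -> torch.Tensor:
    # Hypothetical helper: (1, C, H, W) -> (L, C, p, p) with L = (H//p) * (W//p).
    _, c, h, w = img.shape
    cols = F.unfold(img, kernel_size=p, stride=p)          # (1, C*p*p, L)
    return cols[0].transpose(0, 1).reshape(-1, c, p, p)    # one row per window


def batch_to_img_fold(batch: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # Hypothetical inverse: (L, C, p, p) -> (1, C, H, W).
    n, c, p, _ = batch.shape
    cols = batch.reshape(n, c * p * p).transpose(0, 1).unsqueeze(0)  # (1, C*p*p, L)
    return F.fold(cols, output_size=(h, w), kernel_size=p, stride=p)


img = torch.rand(1, 3, 4, 4, requires_grad=True)
patches = img_to_batch_unfold(img, 2)
restored = batch_to_img_fold(patches, 4, 4)

# Backpropagate through both directions to confirm gradients reach the image.
restored.sum().backward()
print(patches.shape, restored.shape, img.grad.shape)
```

With overlapping windows (stride smaller than the patch size), `fold` sums the overlapping values, so the round trip would no longer reproduce the image without an extra normalization step.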