I am working with 3D image data and would like to implement a transform that randomly crops a 3D image. I have written a class that does this and was wondering if there is room for optimization.
class RandomCrop3D:
    """Randomly crop a 4D tensor (C, H, W, D) to a fixed spatial size.

    The channel dimension is kept intact; each spatial axis is cropped to
    ``crop_sz`` starting at an independently drawn random offset.

    Parameters
    ----------
    img_sz : sequence of 4 ints
        Expected input shape ``(c, h, w, d)``; only the spatial part is used.
    crop_sz : sequence of 3 ints
        Target spatial size ``(h, w, d)``; each entry must be <= the
        corresponding image dimension.
    """

    def __init__(self, img_sz, crop_sz):
        c, h, w, d = img_sz
        # Element-wise check. The original `(h, w, d) > crop_sz` compared
        # tuples lexicographically, which can pass even when a later
        # dimension is smaller than the crop (e.g. (5, 2, 9) > (3, 3, 3)).
        assert all(s >= cs for s, cs in zip((h, w, d), crop_sz)), \
            f"crop size {tuple(crop_sz)} exceeds image size {(h, w, d)}"
        self.img_sz = (h, w, d)
        self.crop_sz = tuple(crop_sz)

    def __call__(self, x):
        """Return a random crop of ``x`` with spatial shape ``crop_sz``."""
        slices = [self._get_slice(s, cs)
                  for s, cs in zip(self.img_sz, self.crop_sz)]
        return self._crop(x, *slices)

    @staticmethod
    def _get_slice(sz, crop_sz):
        """Draw a random (start, stop) pair for one axis of length ``sz``."""
        # When sz == crop_sz there is nothing to randomize: (None, None)
        # slices the full axis. The original reached this case via a bare
        # `except:` around torch.randint(0, ...), which also silently
        # swallowed any unrelated error.
        if sz == crop_sz:
            return None, None
        lower = torch.randint(sz - crop_sz, (1,)).item()
        return lower, lower + crop_sz

    @staticmethod
    def _crop(x, slice_h, slice_w, slice_d):
        """Apply the three (start, stop) pairs to the spatial axes of ``x``."""
        return x[:,
                 slice_h[0]:slice_h[1],
                 slice_w[0]:slice_w[1],
                 slice_d[0]:slice_d[1]]
Example usage:
# Build a random (C, H, W, D) volume: 3 channels, 100^3 spatial extent.
volume_3d = torch.rand(3, 100, 100, 100)
# Configure the transform from the volume's own shape; crop to 64^3.
rand_crop = RandomCrop3D(volume_3d.shape, (64, 64, 64))
# Each call draws new random offsets and returns a (3, 64, 64, 64) tensor.
rand_crop(volume_3d)