Efficient way to crop a 3D image in PyTorch

I'm working with 3D image data and would like to implement a transform that randomly crops a 3D image. I have written a class that does this and am wondering whether there is room for optimization.

class RandomCrop3D():
    """Randomly crop the last three (spatial) dimensions of a volume.

    A new window start is drawn independently along H, W and D on every call,
    so each call returns a (possibly) different sub-volume of size ``crop_sz``.
    The crop is done with basic slicing, which in PyTorch returns a view that
    shares storage with ``x`` (no data copy). Call ``.clone()`` /
    ``.contiguous()`` on the result if an independent tensor is required.

    Args:
        img_sz: Shape of the input volume as ``(C, H, W, D)``, e.g.
            ``tensor.shape``. Only H, W, D are used.
        crop_sz: Crop size ``(h, w, d)``. Each entry must satisfy
            ``0 < k <= `` the matching image dimension.

    Raises:
        ValueError: if ``img_sz`` does not have exactly four entries,
            ``crop_sz`` does not have exactly three, or a crop dimension is
            non-positive or larger than the corresponding image dimension.
    """

    def __init__(self, img_sz, crop_sz):
        # Unpacking raises ValueError if img_sz is not 4-long.
        _, h, w, d = img_sz
        self.img_sz = (h, w, d)
        self.crop_sz = tuple(crop_sz)
        if len(self.crop_sz) != 3:
            raise ValueError(
                f"crop_sz must have 3 entries (h, w, d), got {self.crop_sz}"
            )
        # Validate per dimension: a plain tuple ``>`` is lexicographic and
        # would accept e.g. (100, 10, 10) > (64, 64, 64).
        if not all(0 < k <= s for s, k in zip(self.img_sz, self.crop_sz)):
            raise ValueError(
                f"crop_sz {self.crop_sz} does not fit in image size {self.img_sz}"
            )

    def __call__(self, x):
        """Return a random ``crop_sz`` window over the last three dims of ``x``.

        NOTE(review): ``x`` is not checked against the ``img_sz`` given at
        construction; its last three dims are assumed to match it.
        """
        slice_hwd = [
            self._get_slice(s, k) for s, k in zip(self.img_sz, self.crop_sz)
        ]
        return self._crop(x, *slice_hwd)

    @staticmethod
    def _get_slice(sz, crop_sz):
        """Return a random ``(start, stop)`` pair of length ``crop_sz`` within ``[0, sz)``."""
        # randint's upper bound is exclusive; +1 makes the last valid start
        # (sz - crop_sz) reachable and yields 0 when sz == crop_sz.
        lower_bound = torch.randint(sz - crop_sz + 1, (1,)).item()
        return lower_bound, lower_bound + crop_sz

    @staticmethod
    def _crop(x, slice_h, slice_w, slice_d):
        """Slice the last three dims of ``x`` with the given ``(start, stop)`` pairs."""
        # ``...`` keeps all leading dims (channels, and optionally batch).
        return x[..., slice_h[0]:slice_h[1], slice_w[0]:slice_w[1], slice_d[0]:slice_d[1]]


# Example: a 3-channel 100^3 random volume and a 64^3 random-crop transform.
volume_3d = torch.rand((3, 100, 100, 100))
rand_crop = RandomCrop3D(img_sz=volume_3d.shape, crop_sz=(64, 64, 64))