Make a translation on x

Hello, I want to translate an image along the x axis. I have already written a function that applies a translation with a single value (every pixel is translated by the same amount), so the ground-truth (GT) disparity is constant (every pixel is shifted along x by the same value).
I would like to do the same thing, but with a different shift value at each pixel, and I am struggling with how to do it.
Example:


def _crop_width(tensor, margin):
    """Drop ``margin`` columns from both ends of the last axis of a 4-D tensor.

    An explicit end index is used instead of ``margin:-margin`` because the
    latter degenerates to ``0:0`` (an empty slice) when ``margin`` is 0.
    """
    return tensor[:, :, :, margin:max(tensor.shape[3] - margin, 0)]


def concat_translatex_batch(batch, translate_x_max, scale, num_times):
    """Augment a batch with random horizontal shifts of its ``'input'`` entry.

    For each of the ``num_times`` repetitions, every sample of
    ``batch['input']`` is shifted along its last axis by an integer drawn
    uniformly from ``[-m, m]`` with ``m = round(translate_x_max / scale)``.
    The repetitions are concatenated along dim 0 and ``m`` columns are cropped
    from both sides, which removes every column the shift could not fill.
    All other entries are repeated ``num_times`` along dim 0 and cropped by
    ``translate_x_max`` columns on both sides.

    Args:
        batch: Mapping from key to 4-D tensor; the code indexes tensors as
            ``[sample, :, :, x]``.
            NOTE(review): presumably ``'input'`` lives at a resolution that is
            ``scale`` times coarser than the other entries -- confirm with the
            caller.
        translate_x_max: Maximum absolute shift, in columns of the non-input
            entries. Must be a non-negative integer (it is used as a slice
            bound).
        scale: Divisor converting ``translate_x_max`` into columns of
            ``'input'``.
        num_times: Number of augmented copies of the batch to generate.

    Returns:
        A ``(newbatch, x_list)`` tuple. ``newbatch`` has the same keys as
        ``batch``. ``x_list`` holds the shift applied to each output sample of
        ``newbatch['input']``, in dim-0 order; a positive value means the
        content moved towards lower x indices.
    """
    margin_img = translate_x_max
    margin_evt = round(translate_x_max / scale)

    newbatch = {}
    x_list = []
    for key, tensor in batch.items():
        if key == 'input':
            copies = []
            for _ in range(num_times):
                # Columns not overwritten below keep the fill value 1; they all
                # lie within ``margin_evt`` of a border and are cropped away.
                shifted = torch.ones_like(tensor)
                for b in range(tensor.shape[0]):
                    x = random.randint(-margin_evt, margin_evt)
                    if x > 0:
                        shifted[b, :, :, :-x] = tensor[b, :, :, x:]
                    elif x < 0:
                        shifted[b, :, :, -x:] = tensor[b, :, :, :x]
                    else:
                        shifted[b] = tensor[b]
                    x_list.append(x)
                copies.append(shifted)
            newbatch[key] = _crop_width(torch.cat(copies, dim=0), margin_evt)
        else:
            repeated = torch.cat([tensor] * num_times, dim=0)
            newbatch[key] = _crop_width(repeated, margin_img)

    return newbatch, x_list


Can this be achieved with a slicing operation?

# Shift a 4-D tensor along its last axis (x) by a signed number of pixels,
# filling the vacated columns with zeros.
a = torch.randn(1, 3, 224, 224)
shift = n  # signed shift in pixels; `n` is a placeholder supplied by the reader
width = a.shape[-1]
b = torch.zeros_like(a)
if shift >= 0:
    # Positive shift: content moves towards higher x; the first `shift`
    # columns of `b` stay zero.
    b[:, :, :, shift:] = a[:, :, :, :width - shift]
else:
    # Negative shift: content moves towards lower x; the last `-shift`
    # columns of `b` stay zero. The source slice must be taken on the last
    # axis, starting at column `-shift`.
    b[:, :, :, :width + shift] = a[:, :, :, -shift:]