Padding image with learnable parameters

I finally came up with an efficient way to do this using torch.cat().

import torch
import torch.nn as nn
import torch.nn.functional as F

# Demo batch: b images with c channels of size h x w.
b, c, h, w = 2, 3, 16, 16
image_batch = torch.zeros(b, c, h, w)

# Left/right borders: one learnable column per side, shared by every image in
# the batch (leading dim of 1) and initialised from a standard normal.
left = nn.Parameter(torch.zeros(1, c, h, 1))
nn.init.normal_(left)
right = nn.Parameter(torch.zeros(1, c, h, 1))
nn.init.normal_(right)

# Keep the Parameters bound to their own names and store the broadcast views
# separately: expand() returns a non-leaf view, so rebinding
# `left = left.expand(...)` would drop the only handle to the leaf tensor that
# an optimizer needs.
left_expanded = left.expand(b, -1, -1, -1)
right_expanded = right.expand(b, -1, -1, -1)

pad_image_batch = torch.cat([left_expanded, image_batch, right_expanded], dim=3)

# Top/bottom borders span the already-widened rows (w + 2 columns), so the
# four corners are covered by these parameters.
top = nn.Parameter(torch.zeros(1, c, 1, w + 2))
nn.init.normal_(top)
bottom = nn.Parameter(torch.zeros(1, c, 1, w + 2))
nn.init.normal_(bottom)

top_expanded = top.expand(b, -1, -1, -1)
# Expand `bottom` itself: expanding `top` here would reuse the top border for
# the bottom edge and leave `bottom` unused (it would never get gradients).
bottom_expanded = bottom.expand(b, -1, -1, -1)

# Final shape: (b, c, h + 2, w + 2).
pad_image_batch = torch.cat([top_expanded, pad_image_batch, bottom_expanded], dim=2)

# All learnable border parameters, e.g. for torch.optim.SGD(pad_params, lr=...).
pad_params = [left, right, top, bottom]