Padding an image with learnable parameters

Given an image of shape [C, H, W] (or a batch of shape [B, C, H, W]), we can usually pad it with a constant value using torch.nn.functional.pad(). For example:

# Shape: b c h w
image_batch = torch.zeros(2, 3, 16, 16)
# pad is ordered from the last dim backwards: (left, right, top, bottom),
# i.e. one pixel added on every side of W and H.
pad_image_batch = F.pad(
    image_batch,
    pad=(1, 1, 1, 1),
    mode='constant',
    value=float('-inf'),
)
# Shape: 2 3 18 18

However, I want to pad the image with learnable values stored in a torch.nn.Parameter instead. I came up with the following code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

# Channel first because it's easier to pad
# b c h w
image_batch = torch.zeros(2, 3, 16, 16)
pad_image_batch = F.pad(image_batch,
                        (1, 1, 1, 1),
                        mode='constant',
                        value=float('-inf'))

# Padded spatial size (18 x 18 here), derived rather than hard-coded.
channels = image_batch.shape[1]
rows, cols = pad_image_batch.shape[-2:]

# One learnable C-channel vector per border pixel:
# 2 * (rows + cols) - 4 == 68 for an 18 x 18 padded image.
num_border = 2 * (rows + cols) - 4
visual_prompt = nn.Parameter(torch.randn(num_border, channels))
print(visual_prompt[0])

# Channel last because it's easier to add visual prompt vectors
# (equivalent to einops.rearrange(x, "b c h w -> b h w c")).
pad_image_batch = pad_image_batch.permute(0, 2, 3, 1)

# Boolean mask selecting the one-pixel border of the padded image.
border = torch.zeros(rows, cols, dtype=torch.bool)
border[0, :] = True
border[-1, :] = True
border[:, 0] = True
border[:, -1] = True

# Iterating over a tensor (`for image in pad_image_batch`) yields views from a
# multi-view op, and autograd refuses in-place writes to those. Instead,
# clone() to get a tensor that owns its storage, then write every prompt in a
# single index assignment. Boolean-mask indexing visits the border in
# row-major order, the same order as the original i/j loop. The (num_border, C)
# prompt broadcasts over the batch dim.
pad_image_batch = pad_image_batch.clone()
pad_image_batch[:, border] = visual_prompt
# b h w c, with gradients flowing back into visual_prompt

It raises the following runtime error at the line image[i][j] = visual_prompt[counter]:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[36], line 6
      4 for j in range(18):
      5     if i == 0 or i == rows - 1 or j == 0 or j == cols - 1:
----> 6         image[i][j] = visual_prompt[counter]
      7         counter += 1

RuntimeError: A view was created in no_grad mode and is being modified inplace with grad mode enabled. 
This view is the output of a function that returns multiple views. 
Such functions do not allow the output views to be modified inplace. 
You should replace the inplace operation by an out-of-place one.

At this point, I would appreciate some advice on how to pad with learnable parameters. I suspect there is a better approach that also avoids this runtime error.

I finally came up with an efficient approach using torch.cat():

import torch
import torch.nn as nn
import torch.nn.functional as F

b, c, h, w = 2, 3, 16, 16
image_batch = torch.zeros(b, c, h, w)

# Each learnable strip keeps its own *_param name so it can be handed to an
# optimizer. The expanded tensors built from them are just views, shared
# across the batch without copying.

# Put learnable parameters on left and right of the image
left_param = nn.Parameter(torch.randn(1, c, h, 1))
right_param = nn.Parameter(torch.randn(1, c, h, 1))

left = left_param.expand(b, -1, -1, -1)
right = right_param.expand(b, -1, -1, -1)

pad_image_batch = torch.cat([left, image_batch, right], dim=3)
# b c h (w+2)

# Put learnable parameters on top and bottom of the image. The width is
# w + 2 because the left/right columns were already concatenated.
top_param = nn.Parameter(torch.randn(1, c, 1, w + 2))
bottom_param = nn.Parameter(torch.randn(1, c, 1, w + 2))

top = top_param.expand(b, -1, -1, -1)
# Must expand bottom_param, not top: otherwise the bottom row duplicates the
# top row and bottom_param never receives gradients.
bottom = bottom_param.expand(b, -1, -1, -1)

pad_image_batch = torch.cat([top, pad_image_batch, bottom], dim=2)
# b c (h+2) (w+2)